Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Shape inference of simple operations in LazyTensor. #388

Merged
merged 5 commits into from
Aug 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions Sources/TensorFlow/Core/LazyTensorContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,21 @@ class LazyTensorOperationsTracker {

struct LazyTensorContext {
private var operationsTracker = LazyTensorOperationsTracker()
private var isShapeTrackingEnabled = true

static private var threadLocalContext: LazyTensorContext {
_ThreadLocalState.local.lazyTensorContext
static private var local: LazyTensorContext {
_read { yield _ThreadLocalState.local.lazyTensorContext }
_modify { yield &_ThreadLocalState.local.lazyTensorContext }
}

static var operationsTracker: LazyTensorOperationsTracker {
return threadLocalContext.operationsTracker
return local.operationsTracker
}

/// A flag that determines whether we should track shapes. We will need to disable shape
/// tracking within certain contexts. e.g., we won't be able to compute shapes when tracing.
static var isShapeTrackingEnabled: Bool {
get { local.isShapeTrackingEnabled }
set { local.isShapeTrackingEnabled = newValue }
}
}
31 changes: 27 additions & 4 deletions Sources/TensorFlow/Core/LazyTensorOperation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,29 @@ class LazyTensorHandle: _AnyTensorHandle {
}

/// The number of dimensions of the underlying `Tensor`.
@inlinable
var rank: Int { _tfeTensorHandle.rank }
@usableFromInline
var rank: Int {
@_semantics("autodiff.nonvarying")
get { shape.rank }
}

/// The shape of the underlying `Tensor`.
@inlinable
var shape: TensorShape { _tfeTensorHandle.shape }
@usableFromInline
var shape: TensorShape {
@_semantics("autodiff.nonvarying")
get {
switch handle {
case .symbolic(let op, let index, _):
precondition(LazyTensorContext.isShapeTrackingEnabled,
"Shape tracking is not enabled in this context.")
if let shape = op.outputShapes[index] { return shape }
// Materialize and get the shape from concrete tensor handle.
op.outputShapes[index] = _tfeTensorHandle.shape
return op.outputShapes[index]!
case .concrete(let tfeHandle, _): return tfeHandle.shape
}
}
}

/// Returns the underlying `LazyTensorOperation` if this is a symbolic `LazyTensorHandle`.
var lazyTensorOperation: LazyTensorOperation? {
Expand Down Expand Up @@ -186,6 +203,7 @@ class LazyTensorOperation: TensorOperation {
let outputCount: Int
var inputs: [Input]
var attributes: [String: Attribute]
var outputShapes: [TensorShape?]
var deviceName: String?
var outputs: [TFETensorHandle]?
var id: String?
Expand Down Expand Up @@ -230,6 +248,7 @@ class LazyTensorOperation: TensorOperation {
self.attributes = [:]
self.deviceName = _ExecutionContext.global.currentDeviceName
self.outputCount = outputCount
self.outputShapes = []
self.outputs = nil
self.id = id
LazyTensorOperation.liveOperations += 1
Expand All @@ -244,6 +263,9 @@ class LazyTensorOperation: TensorOperation {
}

func evaluate() -> [LazyTensorHandle] {
if LazyTensorContext.isShapeTrackingEnabled {
updateOutputShapes()
}
return (0..<outputCount).map {
LazyTensorHandle(_lazyLive: self, index: $0)
}
Expand Down Expand Up @@ -869,6 +891,7 @@ extension LazyTensorOperation {
for lazyOp in traceInfo.lazyOperations {
let end = start + lazyOp.outputCount
lazyOp.outputs = Array(allOutputs[start..<end])
lazyOp.outputShapes = lazyOp.outputs!.map { $0.shape }
start = end
}

Expand Down
102 changes: 102 additions & 0 deletions Sources/TensorFlow/Core/LazyTensorShapeInference.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import CTensorFlow

extension LazyTensorOperation {
/// Returns a newly created TFE_Op with only the attributes set. NOTE: the
/// caller should explicitly call `TFE_DeleteOp(tfeOp.op)` and
/// `TFE_DeleteStatus(tfeOp.status)` to free the resources allocated in the
/// newly created TFE_Op.
private var tfeOp: TFE_Op {
let op = TFE_Op(name, outputCount)
for (name, value) in attributes {
switch value {
case .boolValue(let v): op.updateAttribute(name, v)
case .intValue(let v): op.updateAttribute(name, v)
case .floatValue(let v): op.updateAttribute(name, v)
case .doubleValue(let v): op.updateAttribute(name, v)
case .stringValue(let v): op.updateAttribute(name, v)
case .boolArray(let v): op.updateAttribute(name, v)
case .intArray(let v): op.updateAttribute(name, v)
case .floatArray(let v): op.updateAttribute(name, v)
case .doubleArray(let v): op.updateAttribute(name, v)
case .stringArray(let v): op.updateAttribute(name, v)
case .constTensor(_): fatalError("Const Tensor cannot be eager attribute.")
case .tensorDataTypeValue(let v): op.updateAttribute(name, v)
case .tensorDataTypeArray(let v): op.updateAttribute(name, v)
case .optionalTensorShape(let v): op.updateAttribute(name, v)
case .optionalTensorShapeArray(let v): op.updateAttribute(name, v)
case .tensorFunctionPointer(_): fatalError("tensorFunctionPointer Unimplemented!")
}
}
return op
}

func updateOutputShapes() {
let status = TF_NewStatus()
defer { TF_DeleteStatus(status) }

let inputShapes: [TensorShape] = inputs.lazy.flatMap { (input: Input) -> [TensorShape] in
switch input {
case .single(let handle): return [handle.shape]
case .list(let values): return values.lazy.map { $0.shape }
}
}
let inputShapeList = TF_NewShapeAndTypeList(/*num_shapes*/ Int32(inputShapes.count))
defer { TF_DeleteShapeAndTypeList(inputShapeList) }
for (i, shape) in inputShapes.enumerated() {
let int64_dimensions = shape.dimensions.map { Int64($0) }
int64_dimensions.withUnsafeBufferPointer { buffer in
TF_ShapeAndTypeListSetShape(
inputShapeList,
/*index*/ Int32(i),
buffer.baseAddress,
Int32(int64_dimensions.count))
}
}

// This will be filled in by `TFE_InferShapes` and should be freed later.
var outputShapeListPtr = UnsafeMutablePointer<TF_ShapeAndTypeList>(nil)
defer { TF_DeleteShapeAndTypeList(outputShapeListPtr) }

let tfeOp = self.tfeOp
defer {
TFE_DeleteOp(tfeOp.op)
TF_DeleteStatus(tfeOp.status)
}

TFE_InferShapes(
tfeOp.op,
/*input_shapes*/ inputShapeList,
/*input_tensors*/ nil,
/*num_input_tensors*/ 0,
/*input_tensors_as_shapes*/ nil,
/*input_resource_shapes_and_types*/ nil,
/*output_shapes*/ &outputShapeListPtr,
/*output_resource_shapes_and_types*/ nil,
status)

checkOk(status)

precondition(outputShapeListPtr != nil, "TFE_InferShapes returned nil for output shapes")
let outputShapeList = outputShapeListPtr!.pointee
outputShapes = (0..<outputShapeList.num_items).lazy.map { index -> TensorShape? in
let outputShape = outputShapeList.items![Int(index)]
if outputShape.num_dims == -1 { return nil }
let dims = (0..<outputShape.num_dims).lazy.map { Int(outputShape.dims![Int($0)]) }
let hasUnknownDims = dims.contains { $0 == -1 }
return hasUnknownDims ? nil : TensorShape(dims)
}
}
}
4 changes: 4 additions & 0 deletions Sources/TensorFlow/Core/LazyTensorTrace.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class LazyTensorTraceBuilder {
/// Returns a trace obtained by tracing the given function.
static func trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> LazyTensorTrace {
precondition(_ThreadLocalState.useLazyTensor, "Lazy tensor is not enabled for tracing.")
// Disable shape tracking and reset to original state when done.
let isShapeTrackingEnabled = LazyTensorContext.isShapeTrackingEnabled
defer { LazyTensorContext.isShapeTrackingEnabled = isShapeTrackingEnabled }
LazyTensorContext.isShapeTrackingEnabled = false

// Set up inputs for running `fn`.
let inputOps = In._typeList.map { Self.makePlaceholder(dataType: $0) }
Expand Down
71 changes: 71 additions & 0 deletions Tests/TensorFlowTests/LazyTensorShapeInferenceTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest
@testable import TensorFlow
import CTensorFlow

extension LazyTensorOperation {
/// Returns true if the outputs have been materialized.
var isMaterialized: Bool { outputs != nil }
}

final class LazyTensorShapeInferenceTests: XCTestCase {
override class func setUp() {
super.setUp()
_ThreadLocalState.useLazyTensor = true
}

override class func tearDown() {
super.tearDown()
_ThreadLocalState.useLazyTensor = false
}

func testSimpleShapeComputations() {
let a = Tensor<Float>(shape: [3, 1], scalars: [1.0, 2.0, 3.0])
let b = Tensor<Float>(shape: [1, 3], scalars: [1.0, 2.0, 3.0])
let c = Tensor<Float>(shape: [1, 3], scalars: [4.0, 5.0, 6.0])
let w = a * b
let wLazyTensorOperation = w._lazyTensor!.lazyTensorOperation!
let x = w * c
let xLazyTensorOperation = x._lazyTensor!.lazyTensorOperation!

// Make sure that `w` and `x` are not materialized.
XCTAssertFalse(wLazyTensorOperation.isMaterialized)
XCTAssertFalse(xLazyTensorOperation.isMaterialized)

// Examine shape of w and confirm no materialization has happened.
let wShape = w.shape
XCTAssertEqual(wShape.rank, 2)
XCTAssertEqual(wShape.dimensions, [3, 3])
XCTAssertFalse(wLazyTensorOperation.isMaterialized)
XCTAssertFalse(xLazyTensorOperation.isMaterialized)

let xShape = x.shape
XCTAssertEqual(xShape.rank, 2)
XCTAssertEqual(xShape.dimensions, [3, 3])
XCTAssertFalse(wLazyTensorOperation.isMaterialized)
XCTAssertFalse(xLazyTensorOperation.isMaterialized)

// Trigger materialization.
let _ = x._rawTensorHandle
XCTAssertTrue(wLazyTensorOperation.isMaterialized)
XCTAssertTrue(xLazyTensorOperation.isMaterialized)
}

static var allTests = [
("testSimpleShapeComputations", testSimpleShapeComputations)
]
}

1 change: 1 addition & 0 deletions Tests/TensorFlowTests/XCTestManifests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public func allTests() -> [XCTestCaseEntry] {
testCase(LazyTensorTraceTests.allTests),
testCase(LazyTensorExplicitTraceTests.allTests),
testCase(LazyTensorOperationTests.allTests),
testCase(LazyTensorShapeInferenceTests.allTests),
testCase(LazyTensorTFFunctionBuilderTests.allTests),
testCase(LazyTensorEvaluationTests.allTests),
testCase(LossTests.allTests),
Expand Down