Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6967a7b

Browse files
committed
First cut of shape inference in LazyTensor.
1 parent 4291d11 commit 6967a7b

File tree

4 files changed

+143
-5
lines changed

4 files changed

+143
-5
lines changed

Sources/TensorFlow/Core/LazyTensorContext.swift

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,22 @@ class LazyTensorOperationsTracker {
6262

6363
struct LazyTensorContext {
6464
private var operationsTracker = LazyTensorOperationsTracker()
65+
private var _shapeTrackingEnabled = true
6566

6667
static private var threadLocalContext: LazyTensorContext {
67-
_ThreadLocalState.local.lazyTensorContext
68+
_read { yield _ThreadLocalState.local.lazyTensorContext }
69+
_modify { yield &_ThreadLocalState.local.lazyTensorContext }
6870
}
6971

7072
static var operationsTracker: LazyTensorOperationsTracker {
7173
return threadLocalContext.operationsTracker
7274
}
75+
76+
/// A flag that determines whether we should track shapes.
77+
/// We will need to disable shape tracking within certain contexts.
78+
/// e.g., we won't be able to compute shapes when tracing.
79+
static var shapeTrackingEnabled: Bool {
80+
get { threadLocalContext._shapeTrackingEnabled }
81+
set { threadLocalContext._shapeTrackingEnabled = newValue }
82+
}
7383
}

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,29 @@ class LazyTensorHandle: _AnyTensorHandle {
6767
}
6868

6969
/// The number of dimensions of the underlying `Tensor`.
70-
@inlinable
71-
var rank: Int { _tfeTensorHandle.rank }
70+
@usableFromInline
71+
var rank: Int {
72+
@_semantics("autodiff.nonvarying")
73+
get { return shape.rank }
74+
}
7275

7376
/// The shape of the underlying `Tensor`.
74-
@inlinable
75-
var shape: TensorShape { _tfeTensorHandle.shape }
77+
@usableFromInline
78+
var shape: TensorShape {
79+
@_semantics("autodiff.nonvarying")
80+
get {
81+
switch handle {
82+
case .symbolic(let op, let index, _):
83+
precondition(LazyTensorContext.shapeTrackingEnabled,
84+
"Shape tracking is not enabled in this context.")
85+
if let shape = op.outputShapes[index] { return shape }
86+
// Materialize and get the shape from concrete tensor handle.
87+
op.outputShapes[index] = _tfeTensorHandle.shape
88+
return op.outputShapes[index]!
89+
case .concrete(let tfeHandle, _): return tfeHandle.shape
90+
}
91+
}
92+
}
7693

7794
/// Returns the underlying `LazyTensorOperation` if this is a symbolic `LazyTensorHandle`.
7895
var lazyTensorOperation: LazyTensorOperation? {
@@ -186,6 +203,7 @@ class LazyTensorOperation: TensorOperation {
186203
let outputCount: Int
187204
var inputs: [Input]
188205
var attributes: [String: Attribute]
206+
var outputShapes: [TensorShape?]
189207
var deviceName: String?
190208
var outputs: [TFETensorHandle]?
191209
var id: String?
@@ -230,6 +248,7 @@ class LazyTensorOperation: TensorOperation {
230248
self.attributes = [:]
231249
self.deviceName = _ExecutionContext.global.currentDeviceName
232250
self.outputCount = outputCount
251+
self.outputShapes = []
233252
self.outputs = nil
234253
self.id = id
235254
LazyTensorOperation.liveOperations += 1
@@ -244,6 +263,9 @@ class LazyTensorOperation: TensorOperation {
244263
}
245264

246265
func evaluate() -> [LazyTensorHandle] {
266+
if LazyTensorContext.shapeTrackingEnabled {
267+
updateOutputShapes()
268+
}
247269
return (0..<outputCount).map {
248270
LazyTensorHandle(_lazyLive: self, index: $0)
249271
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
import CTensorFlow
15+
16+
extension LazyTensorOperation {
17+
/// Returns a newly created TFE_Op with only the attributes set. NOTE: the
18+
/// caller should explicitly call `TFE_DeleteOp(tfeOp.op)` and
19+
/// `TFE_DeleteStatus(tfeOp.status)` to free the resources allocated in the
20+
/// newly created TFE_Op.
21+
private var tfeOp: TFE_Op {
22+
let op = TFE_Op(name, outputCount)
23+
for (name, value) in attributes {
24+
switch value {
25+
case .boolValue(let v): op.updateAttribute(name, v)
26+
case .intValue(let v): op.updateAttribute(name, v)
27+
case .floatValue(let v): op.updateAttribute(name, v)
28+
case .doubleValue(let v): op.updateAttribute(name, v)
29+
case .stringValue(let v): op.updateAttribute(name, v)
30+
case .boolArray(let v): op.updateAttribute(name, v)
31+
case .intArray(let v): op.updateAttribute(name, v)
32+
case .floatArray(let v): op.updateAttribute(name, v)
33+
case .doubleArray(let v): op.updateAttribute(name, v)
34+
case .stringArray(let v): op.updateAttribute(name, v)
35+
case .constTensor(_): fatalError("Const Tensor cannot be eager attribute.")
36+
case .tensorDataTypeValue(let v): op.updateAttribute(name, v)
37+
case .tensorDataTypeArray(let v): op.updateAttribute(name, v)
38+
case .optionalTensorShape(let v): op.updateAttribute(name, v)
39+
case .optionalTensorShapeArray(let v): op.updateAttribute(name, v)
40+
case .tensorFunctionPointer(_): fatalError("tensorFunctionPointer Unimplemented!")
41+
}
42+
}
43+
return op
44+
}
45+
46+
func updateOutputShapes() {
47+
let status = TF_NewStatus()
48+
defer { TF_DeleteStatus(status) }
49+
50+
let inputShapes: [TensorShape] = inputs.lazy.flatMap { (input: Input) -> [TensorShape] in
51+
switch input {
52+
case .single(let handle): return [handle.shape]
53+
case .list(let values): return values.lazy.map { $0.shape }
54+
}
55+
}
56+
let inputShapeList = TF_NewShapeAndTypeList(/*num_shapes*/ Int32(inputShapes.count))
57+
defer { TF_DeleteShapeAndTypeList(inputShapeList) }
58+
for (i, shape) in inputShapes.enumerated() {
59+
let int64_dimensions = shape.dimensions.map { Int64($0) }
60+
int64_dimensions.withUnsafeBufferPointer { buffer in
61+
TF_ShapeAndTypeListSetShape(
62+
inputShapeList,
63+
/*index*/ Int32(i),
64+
buffer.baseAddress,
65+
Int32(int64_dimensions.count))
66+
}
67+
}
68+
69+
// This will be filled in by `TFE_InferShapes` and should be freed later.
70+
var outputShapeListPtr = UnsafeMutablePointer<TF_ShapeAndTypeList>(nil)
71+
defer { TF_DeleteShapeAndTypeList(outputShapeListPtr) }
72+
73+
let tfeOp = self.tfeOp
74+
defer {
75+
TFE_DeleteOp(tfeOp.op)
76+
TF_DeleteStatus(tfeOp.status)
77+
}
78+
79+
TFE_InferShapes(
80+
tfeOp.op,
81+
/*input_shapes*/ inputShapeList,
82+
/*input_tensors*/ nil,
83+
/*num_input_tensors*/ 0,
84+
/*input_tensors_as_shapes*/ nil,
85+
/*input_resource_shapes_and_types*/ nil,
86+
/*output_shapes*/ &outputShapeListPtr,
87+
/*output_resource_shapes_and_types*/ nil,
88+
status)
89+
90+
checkOk(status)
91+
92+
precondition(outputShapeListPtr != nil, "TFE_InferShapes returned nil for output shapes")
93+
let outputShapeList = outputShapeListPtr!.pointee
94+
outputShapes = (0..<outputShapeList.num_items).lazy.map { index -> TensorShape? in
95+
let outputShape = outputShapeList.items![Int(index)]
96+
if outputShape.num_dims == -1 { return nil }
97+
let dims = (0..<outputShape.num_dims).lazy.map { Int(outputShape.dims![Int($0)]) }
98+
let hasUnknownDims = dims.contains { $0 == -1 }
99+
return hasUnknownDims ? nil : TensorShape(dims)
100+
}
101+
}
102+
}

Sources/TensorFlow/Core/LazyTensorTrace.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ class LazyTensorTraceBuilder {
9292
/// Returns a trace obtained by tracing the given function.
9393
static func trace<In: TensorGroup, Out: TensorGroup>(_ fn: (In) -> Out) -> LazyTensorTrace {
9494
precondition(_ThreadLocalState.useLazyTensor, "Lazy tensor is not enabled for tracing.")
95+
// Disable shape tracking and reset to original state when done.
96+
let shapeTrackingEnabled = LazyTensorContext.shapeTrackingEnabled
97+
defer { LazyTensorContext.shapeTrackingEnabled = shapeTrackingEnabled }
98+
LazyTensorContext.shapeTrackingEnabled = false
9599

96100
// Set up inputs for running `fn`.
97101
let inputOps = In._typeList.map { Self.makePlaceholder(dataType: $0) }

0 commit comments

Comments
 (0)