Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 175a172

Browse files
committed
Shape inference in LazyTensor.
1 parent 2f75a2d commit 175a172

File tree

4 files changed

+151
-2
lines changed

4 files changed

+151
-2
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ class LazyTensorOperation: TensorOperation {
170170
let outputCount: Int
171171
var inputs: [Input]
172172
var attributes: [String: Attribute]
173+
var outputShapes: [TensorShape]
173174
var deviceName: String?
174175
var outputs: [TFETensorHandle]?
175176
var id: String?
@@ -214,6 +215,7 @@ class LazyTensorOperation: TensorOperation {
214215
self.attributes = [:]
215216
self.deviceName = _ExecutionContext.global.currentDeviceName
216217
self.outputCount = outputCount
218+
self.outputShapes = []
217219
self.outputs = nil
218220
self.id = id
219221
LazyTensorOperation.liveOperations += 1
@@ -228,6 +230,7 @@ class LazyTensorOperation: TensorOperation {
228230
}
229231

230232
func evaluate() -> [LazyTensorHandle] {
233+
updateOutputShapes()
231234
return (0..<outputCount).map {
232235
LazyTensorHandle(_lazyLive: self, index: $0)
233236
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import CTensorFlow
2+
3+
extension TFETensorHandle {
4+
var rank: Int {
5+
let status = _ExecutionContext.global.status
6+
let rank = TFE_TensorHandleNumDims(self._cTensorHandle, status)
7+
checkOk(status)
8+
return Int(rank)
9+
}
10+
11+
var shape: TensorShape {
12+
let status = _ExecutionContext.global.status
13+
let dims: [Int] = (0..<Int32(rank)).map { i in
14+
let dim = TFE_TensorHandleDim(self._cTensorHandle, i, status)
15+
checkOk(status)
16+
return Int(dim)
17+
}
18+
return TensorShape(dims)
19+
}
20+
}
21+
22+
extension LazyTensorHandle {
23+
var shape: TensorShape {
24+
switch handle {
25+
case .concrete(let h, _): return h.shape
26+
case .symbolic(let op, let index, _): return op.outputShapes[index]
27+
}
28+
}
29+
}
30+
31+
extension TensorShape {
32+
var tfShape: TF_ShapeAndType {
33+
let int64_dimensions = dimensions.map { Int64($0) }
34+
let cArray = UnsafeMutableBufferPointer<Int64>.allocate(capacity: rank)
35+
let _ = cArray.initialize(from: int64_dimensions)
36+
return TF_ShapeAndType(
37+
num_dims: Int32(rank),
38+
dims: cArray.baseAddress,
39+
dtype: TF_DataType(rawValue: 0) /*TODO*/
40+
)
41+
}
42+
}
43+
44+
extension LazyTensorOperation {
45+
private var tfeOp: TFE_Op {
46+
let op = TFE_Op(name, outputCount)
47+
for (name, value) in attributes {
48+
switch value {
49+
case .boolValue(let v): op.updateAttribute(name, v)
50+
case .intValue(let v): op.updateAttribute(name, v)
51+
case .floatValue(let v): op.updateAttribute(name, v)
52+
case .doubleValue(let v): op.updateAttribute(name, v)
53+
case .stringValue(let v): op.updateAttribute(name, v)
54+
case .boolArray(let v): op.updateAttribute(name, v)
55+
case .intArray(let v): op.updateAttribute(name, v)
56+
case .floatArray(let v): op.updateAttribute(name, v)
57+
case .doubleArray(let v): op.updateAttribute(name, v)
58+
case .stringArray(let v): op.updateAttribute(name, v)
59+
case .constTensor(_): fatalError("Const Tensor cannot be eager attribute.")
60+
case .tensorDataTypeValue(let v): op.updateAttribute(name, v)
61+
case .tensorDataTypeArray(let v): op.updateAttribute(name, v)
62+
case .optionalTensorShape(let v): op.updateAttribute(name, v)
63+
case .optionalTensorShapeArray(let v): op.updateAttribute(name, v)
64+
case .tensorFunctionPointer(_): fatalError("tensorFunctionPointer Unimplemented!")
65+
}
66+
}
67+
return op
68+
}
69+
70+
private func updateOutputShapes() {
71+
let status = TF_NewStatus()
72+
defer { TF_DeleteStatus(status) }
73+
74+
let inputShapes: [TensorShape] = inputs.map {
75+
switch $0 {
76+
case .single(let handle): return handle.shape
77+
case .list(_): fatalError("Unimplemented")
78+
}
79+
}
80+
81+
let inputShapeList = TF_NewShapeAndTypeList(/*num_shapes*/ Int32(inputShapes.count))
82+
for (i, shape) in inputShapes.enumerated() {
83+
let int64_dimensions = shape.dimensions.map { Int64($0) }
84+
int64_dimensions.withUnsafeBufferPointer { buffer in
85+
TF_ShapeAndTypeListSetShape(
86+
inputShapeList,
87+
/*index*/ Int32(i),
88+
buffer.baseAddress,
89+
Int32(int64_dimensions.count))
90+
}
91+
}
92+
93+
// This will be filled in by TFE_InferShapes.
94+
var outputShapeList = UnsafeMutablePointer<TF_ShapeAndTypeList>(nil)
95+
let tfeOp = self.tfeOp
96+
defer {
97+
TFE_DeleteOp(tfeOp.op)
98+
TF_DeleteStatus(tfeOp.status)
99+
}
100+
101+
TFE_InferShapes(
102+
tfeOp.op,
103+
/*input_shapes*/ inputShapeList,
104+
/*input_tensors*/ nil,
105+
/*num_input_tensors*/ 0,
106+
/*input_tensors_as_shapes*/ nil,
107+
/*input_resource_shapes_and_types*/ nil,
108+
/*output_shapes*/ &outputShapeList,
109+
/*output_resource_shapes_and_types*/ nil,
110+
status)
111+
112+
outputShapes = (0..<outputShapeList!.pointee.num_items).map { index -> TensorShape in
113+
let outputShape = outputShapeList!.pointee.items![Int(index)]
114+
let dims = (0..<outputShape.num_dims).map { Int(outputShape.dims![Int($0)]) }
115+
return TensorShape(dims)
116+
}
117+
for (i, outputShape) in outputShapes.enumerated() {
118+
print ("\(i): \(outputShape)")
119+
}
120+
121+
TF_DeleteShapeAndTypeList(inputShapeList)
122+
TF_DeleteShapeAndTypeList(outputShapeList)
123+
}
124+
125+
}

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,15 @@ extension Tensor: AnyTensor {
5151

5252
public extension Tensor {
5353
/// The number of dimensions of the `Tensor`.
54-
@inlinable
54+
// @inlinable
5555
var rank: Int {
5656
@_semantics("autodiff.nonvarying")
5757
get {
58+
if let lazyHandle = handle.handle as? LazyTensorHandle {
59+
if case let .symbolic(op, index, _) = lazyHandle.handle {
60+
return op.outputShapes[index].rank
61+
}
62+
}
5863
let status = _ExecutionContext.global.status
5964
let rank = TFE_TensorHandleNumDims(handle._cTensorHandle, status)
6065
checkOk(status)
@@ -63,10 +68,15 @@ public extension Tensor {
6368
}
6469

6570
/// The shape of the `Tensor`.
66-
@inlinable
71+
// @inlinable
6772
var shape: TensorShape {
6873
@_semantics("autodiff.nonvarying")
6974
get {
75+
if let lazyHandle = handle.handle as? LazyTensorHandle {
76+
if case let .symbolic(op, index, _) = lazyHandle.handle {
77+
return op.outputShapes[index]
78+
}
79+
}
7080
let status = _ExecutionContext.global.status
7181
let dims: [Int] = (0..<Int32(rank)).map { i in
7282
let dim = TFE_TensorHandleDim(self.handle._cTensorHandle, i, status)

Tests/TensorFlowTests/TestShape.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import TensorFlow
2+
_RuntimeConfig.useLazyTensor = true
3+
// _RuntimeConfig.printsDebugLog = true
4+
let a = Tensor<Float>(shape: [3, 1], scalars: [1.0, 2.0, 3.0])
5+
let b = Tensor<Float>(shape: [1, 3], scalars: [1.0, 2.0, 3.0])
6+
let c = Tensor<Float>(10.0)
7+
let w = a * b
8+
print ("\(w.shape)")
9+
let x = w * c
10+
print("\(x)")
11+
print("\(w)")

0 commit comments

Comments
 (0)