Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 620cea6

Browse files
committed
Move TensorHandle related code out of ShapedArray.
1 parent 2b36637 commit 620cea6

File tree

2 files changed

+121
-172
lines changed

2 files changed

+121
-172
lines changed

Sources/TensorFlow/Core/ShapedArray.swift

Lines changed: 30 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
// limitations under the License.
1414

1515
import Swift
16-
import CTensorFlow
1716

1817
//===------------------------------------------------------------------------------------------===//
1918
// TensorBuffer
@@ -25,101 +24,43 @@ import CTensorFlow
2524
/// TensorFlow. In either mode, the buffer object owns the memory and will deallocate it on
2625
/// `deinit`.
2726
@usableFromInline
28-
internal final class TensorBuffer<Scalar> {
29-
typealias Shape = [Int]
30-
31-
/// A reference type wrapping a Swift Array.
32-
/// - Note: An array is used as the native storage for `TensorBuffer`. To make in-place mutation
33-
/// possible when the array is stored in an enumeration value, the array must be wrapped in a
34-
/// reference type.
35-
@usableFromInline
36-
final class BoxedArray {
37-
var array: [Scalar]
38-
39-
init(_ array: __owned [Scalar]) {
40-
self.array = array
41-
}
42-
}
43-
44-
enum Allocation {
45-
case native(BoxedArray)
46-
case tensorFlow(CTensor)
47-
}
27+
internal class TensorBuffer<Scalar> {
28+
let count : Int
4829

49-
let allocation: Allocation
50-
let count: Int
30+
init(count: Int) { self.count = count }
5131

52-
deinit {
53-
debugLog("De-initializing tensor buffer.")
54-
switch allocation {
55-
case .native:
56-
debugLog("Deallocating underlying buffer.")
57-
case let .tensorFlow(cTensor):
58-
debugLog("Deleting underlying tensor.")
59-
TF_DeleteTensor(cTensor)
60-
}
61-
debugLog("Returning from deinit of TensorBuffer.")
32+
func withUnsafeBufferPointer<R>(
33+
_ body: (UnsafeBufferPointer<Scalar>) throws -> R
34+
) rethrows -> R {
35+
fatalError("TensorBuffer should not be constructed directly.")
6236
}
6337

64-
init(allocation: Allocation, count: Int) {
65-
self.allocation = allocation
66-
self.count = count
38+
func withUnsafeMutableBufferPointer<R>(
39+
_ body: (inout UnsafeMutableBufferPointer<Scalar>) throws -> R
40+
) rethrows -> R {
41+
fatalError("TensorBuffer should not be constructed directly.")
6742
}
6843
}
6944

70-
// TF Tensor-specific initializer.
71-
extension TensorBuffer where Scalar: _TensorFlowDataTypeCompatible {
72-
/// Creates a local tensor buffer from a C `TF_Tensor*` value and takes ownership of the value.
73-
convenience init(owning cTensor: CTensor, count: Int) {
74-
debugLog("Initializing TensorBuffer with a cTensor of \(count) elements.")
75-
let actualCount = (0..<TF_NumDims(cTensor)).reduce(1) { accumulator, next in
76-
accumulator * Int(TF_Dim(cTensor, next))
77-
}
78-
assert(actualCount == count)
79-
self.init(allocation: .tensorFlow(cTensor), count: count)
80-
}
81-
}
45+
// TensorBuffer backed by a native swift array.
46+
internal class ArrayTensorBuffer<Scalar> : TensorBuffer<Scalar> {
47+
var array: [Scalar]
8248

83-
// Factory methods.
84-
extension TensorBuffer {
85-
static func create(
86-
count: Int,
87-
withInitializer body: (UnsafeMutableBufferPointer<Scalar>) -> Void
88-
) -> TensorBuffer<Scalar> {
89-
let array = [Scalar](unsafeUninitializedCapacity: count) { buffer, initializedCount in
90-
body(buffer)
91-
initializedCount = count
92-
}
93-
return TensorBuffer(allocation: .native(BoxedArray(array)), count: count)
49+
init(_ array: __owned [Scalar]) {
50+
self.array = array
51+
super.init(count: array.count)
9452
}
95-
}
9653

97-
// Unsafe address accessor.
98-
extension TensorBuffer {
99-
func withUnsafeBufferPointer<R>(
54+
override func withUnsafeBufferPointer<R>(
10055
_ body: (UnsafeBufferPointer<Scalar>) throws -> R
10156
) rethrows -> R {
102-
switch allocation {
103-
case let .native(box):
104-
return try box.array.withUnsafeBufferPointer { pointer in try body(pointer) }
105-
case let .tensorFlow(cTensor):
106-
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
107-
let bufferPointer = UnsafeBufferPointer(start: startAddress, count: count)
108-
return try body(bufferPointer)
109-
}
57+
return try array.withUnsafeBufferPointer(body)
11058
}
11159

112-
func withUnsafeMutableBufferPointer<R>(
60+
override func withUnsafeMutableBufferPointer<R>(
11361
_ body: (inout UnsafeMutableBufferPointer<Scalar>) throws -> R
11462
) rethrows -> R {
115-
switch allocation {
116-
case let .native(box):
117-
return try box.array.withUnsafeMutableBufferPointer { pointer in try body(&pointer) }
118-
case let .tensorFlow(cTensor):
119-
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
120-
var bufferPointer = UnsafeMutableBufferPointer(start: startAddress, count: count)
121-
return try body(&bufferPointer)
122-
}
63+
return try array.withUnsafeMutableBufferPointer(body)
12364
}
12465
}
12566

@@ -444,46 +385,12 @@ fileprivate extension ShapedArray {
444385
if isKnownUniquelyReferenced(&buffer) { return }
445386
let oldBuffer = buffer
446387
debugLog("Unique reference check")
447-
buffer = TensorBuffer.create(count: scalarCount) { bufferPointer in
448-
let pointer = bufferPointer.baseAddress!
449-
oldBuffer.withUnsafeBufferPointer { oldBufferPointer in
450-
let oldPointer = oldBufferPointer.baseAddress!
451-
pointer.initialize(from: oldPointer, count: scalarCount)
452-
}
388+
buffer = oldBuffer.withUnsafeBufferPointer { oldBufferPointer in
389+
ArrayTensorBuffer<Scalar>([Scalar](oldBufferPointer))
453390
}
454391
}
455392
}
456393

457-
internal extension ShapedArray where Scalar: _TensorFlowDataTypeCompatible {
458-
@usableFromInline
459-
init(owning cTensor: CTensor) {
460-
// Including \(Scalar.self) into the message would cause non-deterministic crashes.
461-
debugLog("Initializing ShapedArray from CTensor.")
462-
shape = (0..<TF_NumDims(cTensor)).map { Int(TF_Dim(cTensor, $0)) }
463-
if _RuntimeConfig.printsDebugLog {
464-
// Without this local variable, passing the string directly into debugLog() would not
465-
// work, because 'self' is captured by the auto closure param in debugLog().
466-
let shapeStr = "The shape is \(shape)."
467-
debugLog(shapeStr)
468-
}
469-
buffer = TensorBuffer(owning: cTensor, count: shape.reduce(1, *))
470-
debugLog("Done initializing ShapedArray from CTensor.")
471-
}
472-
473-
@usableFromInline
474-
@inline(never)
475-
init(cTensorHandle: CTensorHandle) {
476-
let status = TF_NewStatus()
477-
let cTensor = TFE_TensorHandleResolve(cTensorHandle, status)
478-
checkOk(status)
479-
TF_DeleteStatus(status)
480-
internalConsistencyCheck(cTensor != nil)
481-
debugLog("# of dims is \(TF_NumDims(cTensor!))")
482-
debugLog("Returning a shaped array.")
483-
self.init(owning: cTensor!)
484-
}
485-
}
486-
487394
public extension ShapedArray {
488395
/// The number of dimensions of the array.
489396
var rank: Int {
@@ -505,35 +412,24 @@ public extension ShapedArray {
505412
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
506413
init(shape: __owned [Int], scalars: __owned [Scalar]) {
507414
precondition(shape.reduce(1, *) == scalars.count, "Scalar count mismatch.")
508-
let buffer = TensorBuffer<Scalar>(allocation: .native(.init(scalars)), count: scalars.count)
415+
let buffer = ArrayTensorBuffer<Scalar>(scalars)
509416
self.init(buffer: buffer, shape: shape)
510417
}
511418

512419
/// Creates a `ShapedArray` with the specified shape and sequence of scalars in row-major order.
513420
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
514421
init<S: Sequence>(shape: __owned [Int], scalars: __shared S) where S.Element == Scalar {
515422
let scalarCount = shape.reduce(1, *)
516-
let buffer = TensorBuffer<Scalar>.create(count: scalarCount) { bufferPointer in
517-
let pointer = bufferPointer.baseAddress!
518-
// TODO: Refactor with better pointer initializers in Swift 4.1.
519-
var i = 0
520-
for scalar in scalars {
521-
guard i < scalarCount else { break }
522-
pointer.advanced(by: i).initialize(to: scalar)
523-
i += 1
524-
}
525-
// If the sequence has fewer elements than the shape needs, this is a precondition
526-
// failure.
527-
precondition(
528-
i == scalarCount,
529-
"The sequence has fewer elements than needed by the shape.")
530-
}
423+
let buffer = ArrayTensorBuffer<Scalar>([Scalar](scalars))
424+
precondition(
425+
buffer.count == scalarCount,
426+
"The sequence has fewer elements than needed by the shape.")
531427
self.init(buffer: buffer, shape: shape)
532428
}
533429

534430
/// Creates a `ShapedArray` from a scalar value.
535431
init(_ scalar: __owned Scalar) {
536-
self.init(buffer: TensorBuffer(allocation: .native(.init([scalar])), count: 1), shape: [])
432+
self.init(buffer: ArrayTensorBuffer([scalar]), shape: [])
537433
}
538434

539435
/// Creates a `ShapedArray` with the specified shape and a single, repeated scalar value.
@@ -552,9 +448,7 @@ public extension ShapedArray {
552448
/// - shape: The shape of the `ShapedArray`.
553449
init(repeating repeatedValue: __owned Scalar, shape: __owned [Int]) {
554450
let scalarCount = shape.reduce(1, *)
555-
let buffer = TensorBuffer<Scalar>(
556-
allocation: .native(.init(Array(repeating: repeatedValue, count: scalarCount))),
557-
count: scalarCount)
451+
let buffer = ArrayTensorBuffer<Scalar>(Array(repeating: repeatedValue, count: scalarCount))
558452
self.init(buffer: buffer, shape: shape)
559453
}
560454
}
@@ -656,42 +550,6 @@ public extension ShapedArray {
656550
}
657551
}
658552

659-
// Tensor conversion.
660-
extension ShapedArray where Scalar: TensorFlowScalar {
661-
var byteCount: Int {
662-
return MemoryLayout<Scalar>.stride * scalarCount
663-
}
664-
665-
@usableFromInline
666-
__consuming func makeTensorHandle() -> TensorHandle<Scalar> {
667-
// This initializer is designed to optimize conversion from TF-allocated
668-
// `ShapedArray` instances.
669-
switch buffer.allocation {
670-
case let .native(box):
671-
precondition(
672-
rank <= Int(Int32.max),
673-
"Conversion to TensorHandle is undefined when rank exceeds `Int32.max`.")
674-
precondition(
675-
shape.allSatisfy { $0 <= Int(Int32.max) },
676-
"Conversion to TensorHandle is undefined when shape dimensions exceed `Int32.max`.")
677-
return TensorHandle<Scalar>(
678-
shape: shape,
679-
scalarsInitializer: { addr in
680-
addr.initialize(from: box.array, count: scalarCount)
681-
})
682-
case let .tensorFlow(cTensor):
683-
return TensorHandle(copyingFromCTensor: cTensor)
684-
}
685-
}
686-
}
687-
688-
// Tensor conversion.
689-
public extension Tensor {
690-
init(_ array: __owned ShapedArray<Scalar>) {
691-
self.init(handle: array.makeTensorHandle())
692-
}
693-
}
694-
695553
// Array literal conversion.
696554
extension ShapedArray: ExpressibleByArrayLiteral where Scalar: TensorFlowScalar {
697555
public typealias ArrayLiteralElement = _TensorElementLiteral<Scalar>

Sources/TensorFlow/Core/TensorHandle.swift

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,94 @@ public struct VariantHandle {
212212
self.handle = handle
213213
}
214214
}
215+
216+
//===------------------------------------------------------------------------------------------===//
217+
// TensorBuffer based on a C `TF_Tensor*`.
218+
//===------------------------------------------------------------------------------------------===//
219+
220+
// TF Tensor-specific initializer.
221+
internal class CTensorTensorBuffer<Scalar> : TensorBuffer<Scalar> {
222+
let cTensor: CTensor
223+
224+
/// Creates a local tensor buffer from a C `TF_Tensor*` value and takes ownership of the value.
225+
init(owning cTensor: CTensor, count: Int) {
226+
debugLog("Initializing TensorBuffer with a cTensor of \(count) elements.")
227+
let actualCount = (0..<TF_NumDims(cTensor)).reduce(1) { accumulator, next in
228+
accumulator * Int(TF_Dim(cTensor, next))
229+
}
230+
assert(actualCount == count)
231+
self.cTensor = cTensor
232+
super.init(count: count)
233+
}
234+
235+
override func withUnsafeBufferPointer<R>(
236+
_ body: (UnsafeBufferPointer<Scalar>) throws -> R
237+
) rethrows -> R {
238+
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
239+
let bufferPointer = UnsafeBufferPointer(start: startAddress, count: count)
240+
return try body(bufferPointer)
241+
}
242+
243+
override func withUnsafeMutableBufferPointer<R>(
244+
_ body: (inout UnsafeMutableBufferPointer<Scalar>) throws -> R
245+
) rethrows -> R {
246+
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
247+
var bufferPointer = UnsafeMutableBufferPointer(start: startAddress, count: count)
248+
return try body(&bufferPointer)
249+
}
250+
251+
deinit {
252+
TF_DeleteTensor(cTensor)
253+
}
254+
}
255+
256+
internal extension ShapedArray where Scalar: _TensorFlowDataTypeCompatible {
257+
@usableFromInline
258+
init(owning cTensor: CTensor) {
259+
// Including \(Scalar.self) into the message would cause non-deterministic crashes.
260+
debugLog("Initializing ShapedArray from CTensor.")
261+
let shape = (0..<TF_NumDims(cTensor)).map { Int(TF_Dim(cTensor, $0)) }
262+
if _RuntimeConfig.printsDebugLog {
263+
// Without this local variable, passing the string directly into debugLog() would not
264+
// work, because 'self' is captured by the auto closure param in debugLog().
265+
let shapeStr = "The shape is \(shape)."
266+
debugLog(shapeStr)
267+
}
268+
self.init(
269+
buffer: CTensorTensorBuffer<Scalar>(owning: cTensor, count: shape.reduce(1, *)),
270+
shape: shape)
271+
debugLog("Done initializing ShapedArray from CTensor.")
272+
}
273+
274+
@usableFromInline
275+
@inline(never)
276+
init(cTensorHandle: CTensorHandle) {
277+
let status = TF_NewStatus()
278+
let cTensor = TFE_TensorHandleResolve(cTensorHandle, status)
279+
checkOk(status)
280+
TF_DeleteStatus(status)
281+
internalConsistencyCheck(cTensor != nil)
282+
debugLog("# of dims is \(TF_NumDims(cTensor!))")
283+
debugLog("Returning a shaped array.")
284+
self.init(owning: cTensor!)
285+
}
286+
}
287+
288+
// Tensor conversion.
289+
public extension Tensor {
290+
init(_ array: __owned ShapedArray<Scalar>) {
291+
precondition(
292+
array.rank <= Int(Int32.max),
293+
"Conversion to TensorHandle is undefined when rank exceeds `Int32.max`.")
294+
precondition(
295+
array.shape.allSatisfy { $0 <= Int(Int32.max) },
296+
"Conversion to TensorHandle is undefined when shape dimensions exceed `Int32.max`.")
297+
if let buffer = array.buffer as? CTensorTensorBuffer<Scalar> {
298+
self = Tensor(handle: TensorHandle(copyingFromCTensor: buffer.cTensor))
299+
} else {
300+
self = array.buffer.withUnsafeBufferPointer { buffer in
301+
return Tensor(shape: TensorShape(array.shape), scalars: buffer)
302+
}
303+
}
304+
}
305+
}

0 commit comments

Comments
 (0)