Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 19875a0

Browse files
committed
Move TensorHandle related code out of ShapedArray.
1 parent 2b36637 commit 19875a0

File tree

2 files changed

+122
-172
lines changed

2 files changed

+122
-172
lines changed

Sources/TensorFlow/Core/ShapedArray.swift

Lines changed: 31 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
// limitations under the License.
1414

1515
import Swift
16-
import CTensorFlow
1716

1817
//===------------------------------------------------------------------------------------------===//
1918
// TensorBuffer
@@ -25,101 +24,44 @@ import CTensorFlow
2524
/// TensorFlow. In either mode, the buffer object owns the memory and will deallocate it on
2625
/// `deinit`.
2726
@usableFromInline
28-
internal final class TensorBuffer<Scalar> {
29-
typealias Shape = [Int]
30-
31-
/// A reference type wrapping a Swift Array.
32-
/// - Note: An array is used as the native storage for `TensorBuffer`. To make in-place mutation
33-
/// possible when the array is stored in an enumeration value, the array must be wrapped in a
34-
/// reference type.
35-
@usableFromInline
36-
final class BoxedArray {
37-
var array: [Scalar]
38-
39-
init(_ array: __owned [Scalar]) {
40-
self.array = array
41-
}
42-
}
43-
44-
enum Allocation {
45-
case native(BoxedArray)
46-
case tensorFlow(CTensor)
47-
}
27+
internal class TensorBuffer<Scalar> {
28+
/// Cached element count of the underlying buffer.
29+
let count : Int
4830

49-
let allocation: Allocation
50-
let count: Int
31+
init(count: Int) { self.count = count }
5132

52-
deinit {
53-
debugLog("De-initializing tensor buffer.")
54-
switch allocation {
55-
case .native:
56-
debugLog("Deallocating underlying buffer.")
57-
case let .tensorFlow(cTensor):
58-
debugLog("Deleting underlying tensor.")
59-
TF_DeleteTensor(cTensor)
60-
}
61-
debugLog("Returning from deinit of TensorBuffer.")
33+
func withUnsafeBufferPointer<R>(
34+
_ body: (UnsafeBufferPointer<Scalar>) throws -> R
35+
) rethrows -> R {
36+
fatalError("withUnsafeBufferPointer unimplemented because TensorBuffer is abstract")
6237
}
6338

64-
init(allocation: Allocation, count: Int) {
65-
self.allocation = allocation
66-
self.count = count
39+
func withUnsafeMutableBufferPointer<R>(
40+
_ body: (inout UnsafeMutableBufferPointer<Scalar>) throws -> R
41+
) rethrows -> R {
42+
fatalError("withUnsafeMutableBufferPointer unimplemented because TensorBuffer is abstract")
6743
}
6844
}
6945

70-
// TF Tensor-specific initializer.
71-
extension TensorBuffer where Scalar: _TensorFlowDataTypeCompatible {
72-
/// Creates a local tensor buffer from a C `TF_Tensor*` value and takes ownership of the value.
73-
convenience init(owning cTensor: CTensor, count: Int) {
74-
debugLog("Initializing TensorBuffer with a cTensor of \(count) elements.")
75-
let actualCount = (0..<TF_NumDims(cTensor)).reduce(1) { accumulator, next in
76-
accumulator * Int(TF_Dim(cTensor, next))
77-
}
78-
assert(actualCount == count)
79-
self.init(allocation: .tensorFlow(cTensor), count: count)
80-
}
81-
}
46+
// TensorBuffer backed by a native swift array.
47+
internal class ArrayTensorBuffer<Scalar> : TensorBuffer<Scalar> {
48+
var array: [Scalar]
8249

83-
// Factory methods.
84-
extension TensorBuffer {
85-
static func create(
86-
count: Int,
87-
withInitializer body: (UnsafeMutableBufferPointer<Scalar>) -> Void
88-
) -> TensorBuffer<Scalar> {
89-
let array = [Scalar](unsafeUninitializedCapacity: count) { buffer, initializedCount in
90-
body(buffer)
91-
initializedCount = count
92-
}
93-
return TensorBuffer(allocation: .native(BoxedArray(array)), count: count)
50+
init(_ array: __owned [Scalar]) {
51+
self.array = array
52+
super.init(count: array.count)
9453
}
95-
}
9654

97-
// Unsafe address accessor.
98-
extension TensorBuffer {
99-
func withUnsafeBufferPointer<R>(
55+
override func withUnsafeBufferPointer<R>(
10056
_ body: (UnsafeBufferPointer<Scalar>) throws -> R
10157
) rethrows -> R {
102-
switch allocation {
103-
case let .native(box):
104-
return try box.array.withUnsafeBufferPointer { pointer in try body(pointer) }
105-
case let .tensorFlow(cTensor):
106-
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
107-
let bufferPointer = UnsafeBufferPointer(start: startAddress, count: count)
108-
return try body(bufferPointer)
109-
}
58+
return try array.withUnsafeBufferPointer(body)
11059
}
11160

112-
func withUnsafeMutableBufferPointer<R>(
61+
override func withUnsafeMutableBufferPointer<R>(
11362
_ body: (inout UnsafeMutableBufferPointer<Scalar>) throws -> R
11463
) rethrows -> R {
115-
switch allocation {
116-
case let .native(box):
117-
return try box.array.withUnsafeMutableBufferPointer { pointer in try body(&pointer) }
118-
case let .tensorFlow(cTensor):
119-
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
120-
var bufferPointer = UnsafeMutableBufferPointer(start: startAddress, count: count)
121-
return try body(&bufferPointer)
122-
}
64+
return try array.withUnsafeMutableBufferPointer(body)
12365
}
12466
}
12567

@@ -444,46 +386,12 @@ fileprivate extension ShapedArray {
444386
if isKnownUniquelyReferenced(&buffer) { return }
445387
let oldBuffer = buffer
446388
debugLog("Unique reference check")
447-
buffer = TensorBuffer.create(count: scalarCount) { bufferPointer in
448-
let pointer = bufferPointer.baseAddress!
449-
oldBuffer.withUnsafeBufferPointer { oldBufferPointer in
450-
let oldPointer = oldBufferPointer.baseAddress!
451-
pointer.initialize(from: oldPointer, count: scalarCount)
452-
}
389+
buffer = oldBuffer.withUnsafeBufferPointer { oldBufferPointer in
390+
ArrayTensorBuffer<Scalar>([Scalar](oldBufferPointer))
453391
}
454392
}
455393
}
456394

457-
internal extension ShapedArray where Scalar: _TensorFlowDataTypeCompatible {
458-
@usableFromInline
459-
init(owning cTensor: CTensor) {
460-
// Including \(Scalar.self) into the message would cause non-deterministic crashes.
461-
debugLog("Initializing ShapedArray from CTensor.")
462-
shape = (0..<TF_NumDims(cTensor)).map { Int(TF_Dim(cTensor, $0)) }
463-
if _RuntimeConfig.printsDebugLog {
464-
// Without this local variable, passing the string directly into debugLog() would not
465-
// work, because 'self' is captured by the auto closure param in debugLog().
466-
let shapeStr = "The shape is \(shape)."
467-
debugLog(shapeStr)
468-
}
469-
buffer = TensorBuffer(owning: cTensor, count: shape.reduce(1, *))
470-
debugLog("Done initializing ShapedArray from CTensor.")
471-
}
472-
473-
@usableFromInline
474-
@inline(never)
475-
init(cTensorHandle: CTensorHandle) {
476-
let status = TF_NewStatus()
477-
let cTensor = TFE_TensorHandleResolve(cTensorHandle, status)
478-
checkOk(status)
479-
TF_DeleteStatus(status)
480-
internalConsistencyCheck(cTensor != nil)
481-
debugLog("# of dims is \(TF_NumDims(cTensor!))")
482-
debugLog("Returning a shaped array.")
483-
self.init(owning: cTensor!)
484-
}
485-
}
486-
487395
public extension ShapedArray {
488396
/// The number of dimensions of the array.
489397
var rank: Int {
@@ -505,35 +413,24 @@ public extension ShapedArray {
505413
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
506414
init(shape: __owned [Int], scalars: __owned [Scalar]) {
507415
precondition(shape.reduce(1, *) == scalars.count, "Scalar count mismatch.")
508-
let buffer = TensorBuffer<Scalar>(allocation: .native(.init(scalars)), count: scalars.count)
416+
let buffer = ArrayTensorBuffer<Scalar>(scalars)
509417
self.init(buffer: buffer, shape: shape)
510418
}
511419

512420
/// Creates a `ShapedArray` with the specified shape and sequence of scalars in row-major order.
513421
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
514422
init<S: Sequence>(shape: __owned [Int], scalars: __shared S) where S.Element == Scalar {
515423
let scalarCount = shape.reduce(1, *)
516-
let buffer = TensorBuffer<Scalar>.create(count: scalarCount) { bufferPointer in
517-
let pointer = bufferPointer.baseAddress!
518-
// TODO: Refactor with better pointer initializers in Swift 4.1.
519-
var i = 0
520-
for scalar in scalars {
521-
guard i < scalarCount else { break }
522-
pointer.advanced(by: i).initialize(to: scalar)
523-
i += 1
524-
}
525-
// If the sequence has fewer elements than the shape needs, this is a precondition
526-
// failure.
527-
precondition(
528-
i == scalarCount,
529-
"The sequence has fewer elements than needed by the shape.")
530-
}
424+
let buffer = ArrayTensorBuffer<Scalar>([Scalar](scalars))
425+
precondition(
426+
buffer.count == scalarCount,
427+
"The sequence has fewer elements than needed by the shape.")
531428
self.init(buffer: buffer, shape: shape)
532429
}
533430

534431
/// Creates a `ShapedArray` from a scalar value.
535432
init(_ scalar: __owned Scalar) {
536-
self.init(buffer: TensorBuffer(allocation: .native(.init([scalar])), count: 1), shape: [])
433+
self.init(buffer: ArrayTensorBuffer([scalar]), shape: [])
537434
}
538435

539436
/// Creates a `ShapedArray` with the specified shape and a single, repeated scalar value.
@@ -552,9 +449,7 @@ public extension ShapedArray {
552449
/// - shape: The shape of the `ShapedArray`.
553450
init(repeating repeatedValue: __owned Scalar, shape: __owned [Int]) {
554451
let scalarCount = shape.reduce(1, *)
555-
let buffer = TensorBuffer<Scalar>(
556-
allocation: .native(.init(Array(repeating: repeatedValue, count: scalarCount))),
557-
count: scalarCount)
452+
let buffer = ArrayTensorBuffer<Scalar>(Array(repeating: repeatedValue, count: scalarCount))
558453
self.init(buffer: buffer, shape: shape)
559454
}
560455
}
@@ -656,42 +551,6 @@ public extension ShapedArray {
656551
}
657552
}
658553

659-
// Tensor conversion.
660-
extension ShapedArray where Scalar: TensorFlowScalar {
661-
var byteCount: Int {
662-
return MemoryLayout<Scalar>.stride * scalarCount
663-
}
664-
665-
@usableFromInline
666-
__consuming func makeTensorHandle() -> TensorHandle<Scalar> {
667-
// This initializer is designed to optimize conversion from TF-allocated
668-
// `ShapedArray` instances.
669-
switch buffer.allocation {
670-
case let .native(box):
671-
precondition(
672-
rank <= Int(Int32.max),
673-
"Conversion to TensorHandle is undefined when rank exceeds `Int32.max`.")
674-
precondition(
675-
shape.allSatisfy { $0 <= Int(Int32.max) },
676-
"Conversion to TensorHandle is undefined when shape dimensions exceed `Int32.max`.")
677-
return TensorHandle<Scalar>(
678-
shape: shape,
679-
scalarsInitializer: { addr in
680-
addr.initialize(from: box.array, count: scalarCount)
681-
})
682-
case let .tensorFlow(cTensor):
683-
return TensorHandle(copyingFromCTensor: cTensor)
684-
}
685-
}
686-
}
687-
688-
// Tensor conversion.
689-
public extension Tensor {
690-
init(_ array: __owned ShapedArray<Scalar>) {
691-
self.init(handle: array.makeTensorHandle())
692-
}
693-
}
694-
695554
// Array literal conversion.
696555
extension ShapedArray: ExpressibleByArrayLiteral where Scalar: TensorFlowScalar {
697556
public typealias ArrayLiteralElement = _TensorElementLiteral<Scalar>

Sources/TensorFlow/Core/TensorHandle.swift

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,94 @@ public struct VariantHandle {
212212
self.handle = handle
213213
}
214214
}
215+
216+
//===------------------------------------------------------------------------------------------===//
217+
// TensorBuffer based on a C `TF_Tensor*`.
218+
//===------------------------------------------------------------------------------------------===//
219+
220+
// TF Tensor-specific initializer.
221+
internal class CTensorTensorBuffer<Scalar> : TensorBuffer<Scalar> {
222+
let cTensor: CTensor
223+
224+
/// Creates a local tensor buffer from a C `TF_Tensor*` value and takes ownership of the value.
225+
init(owning cTensor: CTensor, count: Int) {
226+
debugLog("Initializing TensorBuffer with a cTensor of \(count) elements.")
227+
let actualCount = (0..<TF_NumDims(cTensor)).reduce(1) { accumulator, next in
228+
accumulator * Int(TF_Dim(cTensor, next))
229+
}
230+
assert(actualCount == count)
231+
self.cTensor = cTensor
232+
super.init(count: count)
233+
}
234+
235+
override func withUnsafeBufferPointer<R>(
236+
_ body: (UnsafeBufferPointer<Scalar>) throws -> R
237+
) rethrows -> R {
238+
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
239+
let bufferPointer = UnsafeBufferPointer(start: startAddress, count: count)
240+
return try body(bufferPointer)
241+
}
242+
243+
override func withUnsafeMutableBufferPointer<R>(
244+
_ body: (inout UnsafeMutableBufferPointer<Scalar>) throws -> R
245+
) rethrows -> R {
246+
let startAddress = TF_TensorData(cTensor).assumingMemoryBound(to: Scalar.self)
247+
var bufferPointer = UnsafeMutableBufferPointer(start: startAddress, count: count)
248+
return try body(&bufferPointer)
249+
}
250+
251+
deinit {
252+
TF_DeleteTensor(cTensor)
253+
}
254+
}
255+
256+
internal extension ShapedArray where Scalar: _TensorFlowDataTypeCompatible {
257+
@usableFromInline
258+
init(owning cTensor: CTensor) {
259+
// Including \(Scalar.self) into the message would cause non-deterministic crashes.
260+
debugLog("Initializing ShapedArray from CTensor.")
261+
let shape = (0..<TF_NumDims(cTensor)).map { Int(TF_Dim(cTensor, $0)) }
262+
if _RuntimeConfig.printsDebugLog {
263+
// Without this local variable, passing the string directly into debugLog() would not
264+
// work, because 'self' is captured by the auto closure param in debugLog().
265+
let shapeStr = "The shape is \(shape)."
266+
debugLog(shapeStr)
267+
}
268+
self.init(
269+
buffer: CTensorTensorBuffer<Scalar>(owning: cTensor, count: shape.reduce(1, *)),
270+
shape: shape)
271+
debugLog("Done initializing ShapedArray from CTensor.")
272+
}
273+
274+
@usableFromInline
275+
@inline(never)
276+
init(cTensorHandle: CTensorHandle) {
277+
let status = TF_NewStatus()
278+
let cTensor = TFE_TensorHandleResolve(cTensorHandle, status)
279+
checkOk(status)
280+
TF_DeleteStatus(status)
281+
internalConsistencyCheck(cTensor != nil)
282+
debugLog("# of dims is \(TF_NumDims(cTensor!))")
283+
debugLog("Returning a shaped array.")
284+
self.init(owning: cTensor!)
285+
}
286+
}
287+
288+
// Tensor conversion.
289+
public extension Tensor {
290+
init(_ array: __owned ShapedArray<Scalar>) {
291+
precondition(
292+
array.rank <= Int(Int32.max),
293+
"Conversion to TensorHandle is undefined when rank exceeds `Int32.max`.")
294+
precondition(
295+
array.shape.allSatisfy { $0 <= Int(Int32.max) },
296+
"Conversion to TensorHandle is undefined when shape dimensions exceed `Int32.max`.")
297+
if let buffer = array.buffer as? CTensorTensorBuffer<Scalar> {
298+
self = Tensor(handle: TensorHandle(copyingFromCTensor: buffer.cTensor))
299+
} else {
300+
self = array.buffer.withUnsafeBufferPointer { buffer in
301+
return Tensor(shape: TensorShape(array.shape), scalars: buffer)
302+
}
303+
}
304+
}
305+
}

0 commit comments

Comments
 (0)