13
13
// limitations under the License.
14
14
15
15
import Swift
16
- import CTensorFlow
17
16
18
17
//===------------------------------------------------------------------------------------------===//
19
18
// TensorBuffer
@@ -25,101 +24,44 @@ import CTensorFlow
25
24
/// TensorFlow. In either mode, the buffer object owns the memory and will deallocate it on
26
25
/// `deinit`.
27
26
@usableFromInline
28
- internal final class TensorBuffer < Scalar> {
29
- typealias Shape = [ Int ]
30
-
31
- /// A reference type wrapping a Swift Array.
32
- /// - Note: An array is used as the native storage for `TensorBuffer`. To make in-place mutation
33
- /// possible when the array is stored in an enumeration value, the array must be wrapped in a
34
- /// reference type.
35
- @usableFromInline
36
- final class BoxedArray {
37
- var array : [ Scalar ]
38
-
39
- init ( _ array: __owned [ Scalar ] ) {
40
- self . array = array
41
- }
42
- }
43
-
44
- enum Allocation {
45
- case native( BoxedArray )
46
- case tensorFlow( CTensor )
47
- }
27
+ internal class TensorBuffer < Scalar> {
28
+ /// Cached element count of the underlying buffer.
29
+ let count : Int
48
30
49
- let allocation : Allocation
50
- let count : Int
31
+ init ( count: Int ) { self . count = count }
51
32
52
- deinit {
53
- debugLog ( " De-initializing tensor buffer. " )
54
- switch allocation {
55
- case . native:
56
- debugLog ( " Deallocating underlying buffer. " )
57
- case let . tensorFlow( cTensor) :
58
- debugLog ( " Deleting underlying tensor. " )
59
- TF_DeleteTensor ( cTensor)
60
- }
61
- debugLog ( " Returning from deinit of TensorBuffer. " )
33
+ func withUnsafeBufferPointer< R> (
34
+ _ body: ( UnsafeBufferPointer < Scalar > ) throws -> R
35
+ ) rethrows -> R {
36
+ fatalError ( " withUnsafeBufferPointer unimplemented because TensorBuffer is abstract " )
62
37
}
63
38
64
- init ( allocation: Allocation , count: Int ) {
65
- self . allocation = allocation
66
- self . count = count
39
+ func withUnsafeMutableBufferPointer< R> (
40
+ _ body: ( inout UnsafeMutableBufferPointer < Scalar > ) throws -> R
41
+ ) rethrows -> R {
42
+ fatalError ( " withUnsafeMutableBufferPointer unimplemented because TensorBuffer is abstract " )
67
43
}
68
44
}
69
45
70
- // TF Tensor-specific initializer.
71
- extension TensorBuffer where Scalar: _TensorFlowDataTypeCompatible {
72
- /// Creates a local tensor buffer from a C `TF_Tensor*` value and takes ownership of the value.
73
- convenience init ( owning cTensor: CTensor , count: Int ) {
74
- debugLog ( " Initializing TensorBuffer with a cTensor of \( count) elements. " )
75
- let actualCount = ( 0 ..< TF_NumDims ( cTensor) ) . reduce ( 1 ) { accumulator, next in
76
- accumulator * Int( TF_Dim ( cTensor, next) )
77
- }
78
- assert ( actualCount == count)
79
- self . init ( allocation: . tensorFlow( cTensor) , count: count)
80
- }
81
- }
46
+ // TensorBuffer backed by a native swift array.
47
+ internal class ArrayTensorBuffer < Scalar> : TensorBuffer < Scalar > {
48
+ var array : [ Scalar ]
82
49
83
- // Factory methods.
84
- extension TensorBuffer {
85
- static func create(
86
- count: Int ,
87
- withInitializer body: ( UnsafeMutableBufferPointer < Scalar > ) -> Void
88
- ) -> TensorBuffer < Scalar > {
89
- let array = [ Scalar] ( unsafeUninitializedCapacity: count) { buffer, initializedCount in
90
- body ( buffer)
91
- initializedCount = count
92
- }
93
- return TensorBuffer ( allocation: . native( BoxedArray ( array) ) , count: count)
50
+ init ( _ array: __owned [ Scalar ] ) {
51
+ self . array = array
52
+ super. init ( count: array. count)
94
53
}
95
- }
96
54
97
- // Unsafe address accessor.
98
- extension TensorBuffer {
99
- func withUnsafeBufferPointer< R> (
55
+ override func withUnsafeBufferPointer< R> (
100
56
_ body: ( UnsafeBufferPointer < Scalar > ) throws -> R
101
57
) rethrows -> R {
102
- switch allocation {
103
- case let . native( box) :
104
- return try box. array. withUnsafeBufferPointer { pointer in try body ( pointer) }
105
- case let . tensorFlow( cTensor) :
106
- let startAddress = TF_TensorData ( cTensor) . assumingMemoryBound ( to: Scalar . self)
107
- let bufferPointer = UnsafeBufferPointer ( start: startAddress, count: count)
108
- return try body ( bufferPointer)
109
- }
58
+ return try array. withUnsafeBufferPointer ( body)
110
59
}
111
60
112
- func withUnsafeMutableBufferPointer< R> (
61
+ override func withUnsafeMutableBufferPointer< R> (
113
62
_ body: ( inout UnsafeMutableBufferPointer < Scalar > ) throws -> R
114
63
) rethrows -> R {
115
- switch allocation {
116
- case let . native( box) :
117
- return try box. array. withUnsafeMutableBufferPointer { pointer in try body ( & pointer) }
118
- case let . tensorFlow( cTensor) :
119
- let startAddress = TF_TensorData ( cTensor) . assumingMemoryBound ( to: Scalar . self)
120
- var bufferPointer = UnsafeMutableBufferPointer ( start: startAddress, count: count)
121
- return try body ( & bufferPointer)
122
- }
64
+ return try array. withUnsafeMutableBufferPointer ( body)
123
65
}
124
66
}
125
67
@@ -444,46 +386,12 @@ fileprivate extension ShapedArray {
444
386
if isKnownUniquelyReferenced ( & buffer) { return }
445
387
let oldBuffer = buffer
446
388
debugLog ( " Unique reference check " )
447
- buffer = TensorBuffer . create ( count: scalarCount) { bufferPointer in
448
- let pointer = bufferPointer. baseAddress!
449
- oldBuffer. withUnsafeBufferPointer { oldBufferPointer in
450
- let oldPointer = oldBufferPointer. baseAddress!
451
- pointer. initialize ( from: oldPointer, count: scalarCount)
452
- }
389
+ buffer = oldBuffer. withUnsafeBufferPointer { oldBufferPointer in
390
+ ArrayTensorBuffer < Scalar > ( [ Scalar] ( oldBufferPointer) )
453
391
}
454
392
}
455
393
}
456
394
457
- internal extension ShapedArray where Scalar: _TensorFlowDataTypeCompatible {
458
- @usableFromInline
459
- init ( owning cTensor: CTensor ) {
460
- // Including \(Scalar.self) into the message would cause non-deterministic crashes.
461
- debugLog ( " Initializing ShapedArray from CTensor. " )
462
- shape = ( 0 ..< TF_NumDims ( cTensor) ) . map { Int ( TF_Dim ( cTensor, $0) ) }
463
- if _RuntimeConfig. printsDebugLog {
464
- // Without this local variable, passing the string directly into debugLog() would not
465
- // work, because 'self' is captured by the auto closure param in debugLog().
466
- let shapeStr = " The shape is \( shape) . "
467
- debugLog ( shapeStr)
468
- }
469
- buffer = TensorBuffer ( owning: cTensor, count: shape. reduce ( 1 , * ) )
470
- debugLog ( " Done initializing ShapedArray from CTensor. " )
471
- }
472
-
473
- @usableFromInline
474
- @inline ( never)
475
- init ( cTensorHandle: CTensorHandle ) {
476
- let status = TF_NewStatus ( )
477
- let cTensor = TFE_TensorHandleResolve ( cTensorHandle, status)
478
- checkOk ( status)
479
- TF_DeleteStatus ( status)
480
- internalConsistencyCheck ( cTensor != nil )
481
- debugLog ( " # of dims is \( TF_NumDims ( cTensor!) ) " )
482
- debugLog ( " Returning a shaped array. " )
483
- self . init ( owning: cTensor!)
484
- }
485
- }
486
-
487
395
public extension ShapedArray {
488
396
/// The number of dimensions of the array.
489
397
var rank : Int {
@@ -505,35 +413,24 @@ public extension ShapedArray {
505
413
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
506
414
init ( shape: __owned [ Int ] , scalars: __owned [ Scalar ] ) {
507
415
precondition ( shape. reduce ( 1 , * ) == scalars. count, " Scalar count mismatch. " )
508
- let buffer = TensorBuffer < Scalar > ( allocation : . native ( . init ( scalars) ) , count : scalars . count )
416
+ let buffer = ArrayTensorBuffer < Scalar > ( scalars)
509
417
self . init ( buffer: buffer, shape: shape)
510
418
}
511
419
512
420
/// Creates a `ShapedArray` with the specified shape and sequence of scalars in row-major order.
513
421
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
514
422
init < S: Sequence > ( shape: __owned [ Int ] , scalars: __shared S) where S. Element == Scalar {
515
423
let scalarCount = shape. reduce ( 1 , * )
516
- let buffer = TensorBuffer< Scalar> . create( count: scalarCount) { bufferPointer in
517
- let pointer = bufferPointer. baseAddress!
518
- // TODO: Refactor with better pointer initializers in Swift 4.1.
519
- var i = 0
520
- for scalar in scalars {
521
- guard i < scalarCount else { break }
522
- pointer. advanced ( by: i) . initialize ( to: scalar)
523
- i += 1
524
- }
525
- // If the sequence has fewer elements than the shape needs, this is a precondition
526
- // failure.
527
- precondition (
528
- i == scalarCount,
529
- " The sequence has fewer elements than needed by the shape. " )
530
- }
424
+ let buffer = ArrayTensorBuffer < Scalar > ( [ Scalar] ( scalars) )
425
+ precondition (
426
+ buffer. count == scalarCount,
427
+ " The sequence has fewer elements than needed by the shape. " )
531
428
self . init ( buffer: buffer, shape: shape)
532
429
}
533
430
534
431
/// Creates a `ShapedArray` from a scalar value.
535
432
init ( _ scalar: __owned Scalar) {
536
- self . init ( buffer: TensorBuffer ( allocation : . native ( . init ( [ scalar] ) ) , count : 1 ) , shape: [ ] )
433
+ self . init ( buffer: ArrayTensorBuffer ( [ scalar] ) , shape: [ ] )
537
434
}
538
435
539
436
/// Creates a `ShapedArray` with the specified shape and a single, repeated scalar value.
@@ -552,9 +449,7 @@ public extension ShapedArray {
552
449
/// - shape: The shape of the `ShapedArray`.
553
450
init ( repeating repeatedValue: __owned Scalar, shape: __owned [ Int ] ) {
554
451
let scalarCount = shape. reduce ( 1 , * )
555
- let buffer = TensorBuffer < Scalar > (
556
- allocation: . native( . init( Array ( repeating: repeatedValue, count: scalarCount) ) ) ,
557
- count: scalarCount)
452
+ let buffer = ArrayTensorBuffer < Scalar > ( Array ( repeating: repeatedValue, count: scalarCount) )
558
453
self . init ( buffer: buffer, shape: shape)
559
454
}
560
455
}
@@ -656,42 +551,6 @@ public extension ShapedArray {
656
551
}
657
552
}
658
553
659
- // Tensor conversion.
660
- extension ShapedArray where Scalar: TensorFlowScalar {
661
- var byteCount : Int {
662
- return MemoryLayout < Scalar > . stride * scalarCount
663
- }
664
-
665
- @usableFromInline
666
- __consuming func makeTensorHandle( ) -> TensorHandle < Scalar > {
667
- // This initializer is designed to optimize conversion from TF-allocated
668
- // `ShapedArray` instances.
669
- switch buffer. allocation {
670
- case let . native( box) :
671
- precondition (
672
- rank <= Int ( Int32 . max) ,
673
- " Conversion to TensorHandle is undefined when rank exceeds `Int32.max`. " )
674
- precondition (
675
- shape. allSatisfy { $0 <= Int ( Int32 . max) } ,
676
- " Conversion to TensorHandle is undefined when shape dimensions exceed `Int32.max`. " )
677
- return TensorHandle < Scalar > (
678
- shape: shape,
679
- scalarsInitializer: { addr in
680
- addr. initialize ( from: box. array, count: scalarCount)
681
- } )
682
- case let . tensorFlow( cTensor) :
683
- return TensorHandle ( copyingFromCTensor: cTensor)
684
- }
685
- }
686
- }
687
-
688
- // Tensor conversion.
689
- public extension Tensor {
690
- init ( _ array: __owned ShapedArray< Scalar > ) {
691
- self . init ( handle: array. makeTensorHandle ( ) )
692
- }
693
- }
694
-
695
554
// Array literal conversion.
696
555
extension ShapedArray : ExpressibleByArrayLiteral where Scalar: TensorFlowScalar {
697
556
public typealias ArrayLiteralElement = _TensorElementLiteral < Scalar >
0 commit comments