13
13
// limitations under the License.
14
14
15
15
import Swift
16
- import CTensorFlow
17
16
18
17
//===------------------------------------------------------------------------------------------===//
19
18
// TensorBuffer
@@ -25,101 +24,43 @@ import CTensorFlow
25
24
/// TensorFlow. In either mode, the buffer object owns the memory and will deallocate it on
26
25
/// `deinit`.
27
26
@usableFromInline
28
- internal final class TensorBuffer < Scalar> {
29
- typealias Shape = [ Int ]
30
-
31
- /// A reference type wrapping a Swift Array.
32
- /// - Note: An array is used as the native storage for `TensorBuffer`. To make in-place mutation
33
- /// possible when the array is stored in an enumeration value, the array must be wrapped in a
34
- /// reference type.
35
- @usableFromInline
36
- final class BoxedArray {
37
- var array : [ Scalar ]
38
-
39
- init ( _ array: __owned [ Scalar ] ) {
40
- self . array = array
41
- }
42
- }
43
-
44
- enum Allocation {
45
- case native( BoxedArray )
46
- case tensorFlow( CTensor )
47
- }
27
+ internal class TensorBuffer < Scalar> {
28
+ let count : Int
48
29
49
- let allocation : Allocation
50
- let count : Int
30
+ init ( count: Int ) { self . count = count }
51
31
52
- deinit {
53
- debugLog ( " De-initializing tensor buffer. " )
54
- switch allocation {
55
- case . native:
56
- debugLog ( " Deallocating underlying buffer. " )
57
- case let . tensorFlow( cTensor) :
58
- debugLog ( " Deleting underlying tensor. " )
59
- TF_DeleteTensor ( cTensor)
60
- }
61
- debugLog ( " Returning from deinit of TensorBuffer. " )
32
+ func withUnsafeBufferPointer< R> (
33
+ _ body: ( UnsafeBufferPointer < Scalar > ) throws -> R
34
+ ) rethrows -> R {
35
+ fatalError ( " TensorBuffer should not be constructed directly. " )
62
36
}
63
37
64
- init ( allocation: Allocation , count: Int ) {
65
- self . allocation = allocation
66
- self . count = count
38
+ func withUnsafeMutableBufferPointer< R> (
39
+ _ body: ( inout UnsafeMutableBufferPointer < Scalar > ) throws -> R
40
+ ) rethrows -> R {
41
+ fatalError ( " TensorBuffer should not be constructed directly. " )
67
42
}
68
43
}
69
44
70
- // TF Tensor-specific initializer.
71
- extension TensorBuffer where Scalar: _TensorFlowDataTypeCompatible {
72
- /// Creates a local tensor buffer from a C `TF_Tensor*` value and takes ownership of the value.
73
- convenience init ( owning cTensor: CTensor , count: Int ) {
74
- debugLog ( " Initializing TensorBuffer with a cTensor of \( count) elements. " )
75
- let actualCount = ( 0 ..< TF_NumDims ( cTensor) ) . reduce ( 1 ) { accumulator, next in
76
- accumulator * Int( TF_Dim ( cTensor, next) )
77
- }
78
- assert ( actualCount == count)
79
- self . init ( allocation: . tensorFlow( cTensor) , count: count)
80
- }
81
- }
45
+ // TensorBuffer backed by a native swift array.
46
+ internal class ArrayTensorBuffer < Scalar> : TensorBuffer < Scalar > {
47
+ var array : [ Scalar ]
82
48
83
- // Factory methods.
84
- extension TensorBuffer {
85
- static func create(
86
- count: Int ,
87
- withInitializer body: ( UnsafeMutableBufferPointer < Scalar > ) -> Void
88
- ) -> TensorBuffer < Scalar > {
89
- let array = [ Scalar] ( unsafeUninitializedCapacity: count) { buffer, initializedCount in
90
- body ( buffer)
91
- initializedCount = count
92
- }
93
- return TensorBuffer ( allocation: . native( BoxedArray ( array) ) , count: count)
49
+ init ( _ array: __owned [ Scalar ] ) {
50
+ self . array = array
51
+ super. init ( count: array. count)
94
52
}
95
- }
96
53
97
- // Unsafe address accessor.
98
- extension TensorBuffer {
99
- func withUnsafeBufferPointer< R> (
54
+ override func withUnsafeBufferPointer< R> (
100
55
_ body: ( UnsafeBufferPointer < Scalar > ) throws -> R
101
56
) rethrows -> R {
102
- switch allocation {
103
- case let . native( box) :
104
- return try box. array. withUnsafeBufferPointer { pointer in try body ( pointer) }
105
- case let . tensorFlow( cTensor) :
106
- let startAddress = TF_TensorData ( cTensor) . assumingMemoryBound ( to: Scalar . self)
107
- let bufferPointer = UnsafeBufferPointer ( start: startAddress, count: count)
108
- return try body ( bufferPointer)
109
- }
57
+ return try array. withUnsafeBufferPointer ( body)
110
58
}
111
59
112
- func withUnsafeMutableBufferPointer< R> (
60
+ override func withUnsafeMutableBufferPointer< R> (
113
61
_ body: ( inout UnsafeMutableBufferPointer < Scalar > ) throws -> R
114
62
) rethrows -> R {
115
- switch allocation {
116
- case let . native( box) :
117
- return try box. array. withUnsafeMutableBufferPointer { pointer in try body ( & pointer) }
118
- case let . tensorFlow( cTensor) :
119
- let startAddress = TF_TensorData ( cTensor) . assumingMemoryBound ( to: Scalar . self)
120
- var bufferPointer = UnsafeMutableBufferPointer ( start: startAddress, count: count)
121
- return try body ( & bufferPointer)
122
- }
63
+ return try array. withUnsafeMutableBufferPointer ( body)
123
64
}
124
65
}
125
66
@@ -444,46 +385,12 @@ fileprivate extension ShapedArray {
444
385
if isKnownUniquelyReferenced ( & buffer) { return }
445
386
let oldBuffer = buffer
446
387
debugLog ( " Unique reference check " )
447
- buffer = TensorBuffer . create ( count: scalarCount) { bufferPointer in
448
- let pointer = bufferPointer. baseAddress!
449
- oldBuffer. withUnsafeBufferPointer { oldBufferPointer in
450
- let oldPointer = oldBufferPointer. baseAddress!
451
- pointer. initialize ( from: oldPointer, count: scalarCount)
452
- }
388
+ buffer = oldBuffer. withUnsafeBufferPointer { oldBufferPointer in
389
+ ArrayTensorBuffer < Scalar > ( [ Scalar] ( oldBufferPointer) )
453
390
}
454
391
}
455
392
}
456
393
457
- internal extension ShapedArray where Scalar: _TensorFlowDataTypeCompatible {
458
- @usableFromInline
459
- init ( owning cTensor: CTensor ) {
460
- // Including \(Scalar.self) into the message would cause non-deterministic crashes.
461
- debugLog ( " Initializing ShapedArray from CTensor. " )
462
- shape = ( 0 ..< TF_NumDims ( cTensor) ) . map { Int ( TF_Dim ( cTensor, $0) ) }
463
- if _RuntimeConfig. printsDebugLog {
464
- // Without this local variable, passing the string directly into debugLog() would not
465
- // work, because 'self' is captured by the auto closure param in debugLog().
466
- let shapeStr = " The shape is \( shape) . "
467
- debugLog ( shapeStr)
468
- }
469
- buffer = TensorBuffer ( owning: cTensor, count: shape. reduce ( 1 , * ) )
470
- debugLog ( " Done initializing ShapedArray from CTensor. " )
471
- }
472
-
473
- @usableFromInline
474
- @inline ( never)
475
- init ( cTensorHandle: CTensorHandle ) {
476
- let status = TF_NewStatus ( )
477
- let cTensor = TFE_TensorHandleResolve ( cTensorHandle, status)
478
- checkOk ( status)
479
- TF_DeleteStatus ( status)
480
- internalConsistencyCheck ( cTensor != nil )
481
- debugLog ( " # of dims is \( TF_NumDims ( cTensor!) ) " )
482
- debugLog ( " Returning a shaped array. " )
483
- self . init ( owning: cTensor!)
484
- }
485
- }
486
-
487
394
public extension ShapedArray {
488
395
/// The number of dimensions of the array.
489
396
var rank : Int {
@@ -505,35 +412,24 @@ public extension ShapedArray {
505
412
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
506
413
init ( shape: __owned [ Int ] , scalars: __owned [ Scalar ] ) {
507
414
precondition ( shape. reduce ( 1 , * ) == scalars. count, " Scalar count mismatch. " )
508
- let buffer = TensorBuffer < Scalar > ( allocation : . native ( . init ( scalars) ) , count : scalars . count )
415
+ let buffer = ArrayTensorBuffer < Scalar > ( scalars)
509
416
self . init ( buffer: buffer, shape: shape)
510
417
}
511
418
512
419
/// Creates a `ShapedArray` with the specified shape and sequence of scalars in row-major order.
513
420
/// - Precondition: The number of scalars must equal the product of the dimensions of the shape.
514
421
init < S: Sequence > ( shape: __owned [ Int ] , scalars: __shared S) where S. Element == Scalar {
515
422
let scalarCount = shape. reduce ( 1 , * )
516
- let buffer = TensorBuffer< Scalar> . create( count: scalarCount) { bufferPointer in
517
- let pointer = bufferPointer. baseAddress!
518
- // TODO: Refactor with better pointer initializers in Swift 4.1.
519
- var i = 0
520
- for scalar in scalars {
521
- guard i < scalarCount else { break }
522
- pointer. advanced ( by: i) . initialize ( to: scalar)
523
- i += 1
524
- }
525
- // If the sequence has fewer elements than the shape needs, this is a precondition
526
- // failure.
527
- precondition (
528
- i == scalarCount,
529
- " The sequence has fewer elements than needed by the shape. " )
530
- }
423
+ let buffer = ArrayTensorBuffer < Scalar > ( [ Scalar] ( scalars) )
424
+ precondition (
425
+ buffer. count == scalarCount,
426
+ " The sequence has fewer elements than needed by the shape. " )
531
427
self . init ( buffer: buffer, shape: shape)
532
428
}
533
429
534
430
/// Creates a `ShapedArray` from a scalar value.
535
431
init ( _ scalar: __owned Scalar) {
536
- self . init ( buffer: TensorBuffer ( allocation : . native ( . init ( [ scalar] ) ) , count : 1 ) , shape: [ ] )
432
+ self . init ( buffer: ArrayTensorBuffer ( [ scalar] ) , shape: [ ] )
537
433
}
538
434
539
435
/// Creates a `ShapedArray` with the specified shape and a single, repeated scalar value.
@@ -552,9 +448,7 @@ public extension ShapedArray {
552
448
/// - shape: The shape of the `ShapedArray`.
553
449
init ( repeating repeatedValue: __owned Scalar, shape: __owned [ Int ] ) {
554
450
let scalarCount = shape. reduce ( 1 , * )
555
- let buffer = TensorBuffer < Scalar > (
556
- allocation: . native( . init( Array ( repeating: repeatedValue, count: scalarCount) ) ) ,
557
- count: scalarCount)
451
+ let buffer = ArrayTensorBuffer < Scalar > ( Array ( repeating: repeatedValue, count: scalarCount) )
558
452
self . init ( buffer: buffer, shape: shape)
559
453
}
560
454
}
@@ -656,42 +550,6 @@ public extension ShapedArray {
656
550
}
657
551
}
658
552
659
- // Tensor conversion.
660
- extension ShapedArray where Scalar: TensorFlowScalar {
661
- var byteCount : Int {
662
- return MemoryLayout < Scalar > . stride * scalarCount
663
- }
664
-
665
- @usableFromInline
666
- __consuming func makeTensorHandle( ) -> TensorHandle < Scalar > {
667
- // This initializer is designed to optimize conversion from TF-allocated
668
- // `ShapedArray` instances.
669
- switch buffer. allocation {
670
- case let . native( box) :
671
- precondition (
672
- rank <= Int ( Int32 . max) ,
673
- " Conversion to TensorHandle is undefined when rank exceeds `Int32.max`. " )
674
- precondition (
675
- shape. allSatisfy { $0 <= Int ( Int32 . max) } ,
676
- " Conversion to TensorHandle is undefined when shape dimensions exceed `Int32.max`. " )
677
- return TensorHandle < Scalar > (
678
- shape: shape,
679
- scalarsInitializer: { addr in
680
- addr. initialize ( from: box. array, count: scalarCount)
681
- } )
682
- case let . tensorFlow( cTensor) :
683
- return TensorHandle ( copyingFromCTensor: cTensor)
684
- }
685
- }
686
- }
687
-
688
- // Tensor conversion.
689
- public extension Tensor {
690
- init ( _ array: __owned ShapedArray< Scalar > ) {
691
- self . init ( handle: array. makeTensorHandle ( ) )
692
- }
693
- }
694
-
695
553
// Array literal conversion.
696
554
extension ShapedArray : ExpressibleByArrayLiteral where Scalar: TensorFlowScalar {
697
555
public typealias ArrayLiteralElement = _TensorElementLiteral < Scalar >
0 commit comments