13
13
// limitations under the License.
14
14
15
15
import CTensorFlow
16
- import Foundation
17
16
18
17
/// A TensorFlow checkpoint file reader.
19
18
public class TensorFlowCheckpointReader {
20
- @ usableFromInline internal let status : OpaquePointer
21
- @ usableFromInline internal let handle : OpaquePointer
19
+ internal let status : OpaquePointer
20
+ internal let handle : OpaquePointer
22
21
23
- /// URL of the checkpoint file.
24
- public let checkpointPath : URL
22
+ /// The path to the checkpoint file.
23
+ public let checkpointPath : String
25
24
26
- /// Number of tensors stored in the checkpoint.
25
+ /// The number of tensors stored in the checkpoint.
27
26
public var tensorCount : Int { Int ( TF_CheckpointReaderSize ( handle) ) }
28
27
29
- /// Names of the tensors stored in the checkpoint.
28
+ /// The names of the tensors stored in the checkpoint.
30
29
public var tensorNames : [ String ] {
31
30
( 0 ..< tensorCount) . map {
32
31
String ( cString: TF_CheckpointReaderGetVariable ( handle, Int32 ( $0) ) )
@@ -36,28 +35,26 @@ public class TensorFlowCheckpointReader {
36
35
/// Creates a new TensorFlow checkpoint reader.
37
36
///
38
37
/// - Arguments:
39
- /// - checkpointPath: URL of the checkpoint file.
40
- @inlinable
41
- public init ? ( checkpointPath: URL ) {
38
+ /// - checkpointPath: Path to the checkpoint file.
39
+ public init ( checkpointPath: String ) {
42
40
self . status = TF_NewStatus ( )
43
- self . handle = TF_NewCheckpointReader ( checkpointPath. path , status)
41
+ self . handle = TF_NewCheckpointReader ( checkpointPath, status)
44
42
checkOk ( status)
45
43
self . checkpointPath = checkpointPath
46
44
}
47
45
48
46
deinit {
47
+ TF_DeleteStatus ( status)
49
48
TF_DeleteCheckpointReader ( handle)
50
49
}
51
50
52
51
/// Returns `true` if the checkpoint contains a tensor with the provided name.
53
- @inlinable
54
- public func contains( tensorNamed name: String ) -> Bool {
52
+ public func containsTensor( named name: String ) -> Bool {
55
53
TF_CheckpointReaderHasTensor ( handle, name) > 0
56
54
}
57
55
58
56
/// Returns the shape of the tensor with the provided name stored in the checkpoint.
59
- @inlinable
60
- public func shape( ofTensorNamed name: String ) -> TensorShape {
57
+ public func shapeOfTensor( named name: String ) -> TensorShape {
61
58
let rank = TF_CheckpointReaderGetVariableNumDims ( handle, name)
62
59
let dimensions = UnsafeMutablePointer< Int64> . allocate( capacity: Int ( rank) )
63
60
defer { dimensions. deallocate ( ) }
@@ -67,16 +64,30 @@ public class TensorFlowCheckpointReader {
67
64
return TensorShape ( [ Int64] ( dimensionsBufferPointer) . map ( Int . init) )
68
65
}
69
66
70
- /// Returns the data type of the tensor with the provided name stored in the checkpoint.
71
- @inlinable
72
- public func dataType( ofTensorNamed name: String ) -> TensorDataType {
73
- TensorDataType ( TF_CheckpointReaderGetVariableDataType ( handle, name) )
67
+ /// Returns the scalar type of the tensor with the provided name stored in the checkpoint.
68
+ public func scalarTypeOfTensor( named name: String ) -> Any . Type {
69
+ let dataType = TensorDataType ( TF_CheckpointReaderGetVariableDataType ( handle, name) )
70
+ switch dataType. _cDataType {
71
+ case TF_BOOL: return Bool . self
72
+ case TF_INT8: return Int8 . self
73
+ case TF_UINT8: return UInt8 . self
74
+ case TF_INT16: return Int16 . self
75
+ case TF_UINT16: return UInt16 . self
76
+ case TF_INT32: return Int32 . self
77
+ case TF_UINT32: return UInt32 . self
78
+ case TF_INT64: return Int64 . self
79
+ case TF_UINT64: return UInt64 . self
80
+ case TF_BFLOAT16: return BFloat16 . self
81
+ case TF_FLOAT: return Float . self
82
+ case TF_DOUBLE: return Double . self
83
+ case TF_STRING: return String . self
84
+ default : fatalError ( " Unhandled type: \( dataType) " )
85
+ }
74
86
}
75
87
76
88
/// Loads and returns the value of the tensor with the provided name stored in the checkpoint.
77
- @inlinable
78
- public func load< Scalar: _TensorFlowDataTypeCompatible > (
79
- tensorNamed name: String
89
+ public func loadTensor< Scalar: _TensorFlowDataTypeCompatible > (
90
+ named name: String
80
91
) -> ShapedArray < Scalar > {
81
92
let pointer = TF_CheckpointReaderGetTensor ( handle, name, status)
82
93
checkOk ( status)
0 commit comments