Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 11ed29d

Browse files
authored
Addressed comments from #529 related to the TensorFlow checkpoint reader. (#531)
1 parent 7e19749 commit 11ed29d

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

Sources/TensorFlow/Core/Serialization.swift

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313
// limitations under the License.
1414

1515
import CTensorFlow
16-
import Foundation
1716

1817
/// A TensorFlow checkpoint file reader.
1918
public class TensorFlowCheckpointReader {
20-
@usableFromInline internal let status: OpaquePointer
21-
@usableFromInline internal let handle: OpaquePointer
19+
internal let status: OpaquePointer
20+
internal let handle: OpaquePointer
2221

23-
/// URL of the checkpoint file.
24-
public let checkpointPath: URL
22+
/// The path to the checkpoint file.
23+
public let checkpointPath: String
2524

26-
/// Number of tensors stored in the checkpoint.
25+
/// The number of tensors stored in the checkpoint.
2726
public var tensorCount: Int { Int(TF_CheckpointReaderSize(handle)) }
2827

29-
/// Names of the tensors stored in the checkpoint.
28+
/// The names of the tensors stored in the checkpoint.
3029
public var tensorNames: [String] {
3130
(0..<tensorCount).map {
3231
String(cString: TF_CheckpointReaderGetVariable(handle, Int32($0)))
@@ -36,28 +35,26 @@ public class TensorFlowCheckpointReader {
3635
/// Creates a new TensorFlow checkpoint reader.
3736
///
3837
/// - Arguments:
39-
/// - checkpointPath: URL of the checkpoint file.
40-
@inlinable
41-
public init?(checkpointPath: URL) {
38+
/// - checkpointPath: Path to the checkpoint file.
39+
public init(checkpointPath: String) {
4240
self.status = TF_NewStatus()
43-
self.handle = TF_NewCheckpointReader(checkpointPath.path, status)
41+
self.handle = TF_NewCheckpointReader(checkpointPath, status)
4442
checkOk(status)
4543
self.checkpointPath = checkpointPath
4644
}
4745

4846
deinit {
47+
TF_DeleteStatus(status)
4948
TF_DeleteCheckpointReader(handle)
5049
}
5150

5251
/// Returns `true` if the checkpoint contains a tensor with the provided name.
53-
@inlinable
54-
public func contains(tensorNamed name: String) -> Bool {
52+
public func containsTensor(named name: String) -> Bool {
5553
TF_CheckpointReaderHasTensor(handle, name) > 0
5654
}
5755

5856
/// Returns the shape of the tensor with the provided name stored in the checkpoint.
59-
@inlinable
60-
public func shape(ofTensorNamed name: String) -> TensorShape {
57+
public func shapeOfTensor(named name: String) -> TensorShape {
6158
let rank = TF_CheckpointReaderGetVariableNumDims(handle, name)
6259
let dimensions = UnsafeMutablePointer<Int64>.allocate(capacity: Int(rank))
6360
defer { dimensions.deallocate() }
@@ -67,16 +64,30 @@ public class TensorFlowCheckpointReader {
6764
return TensorShape([Int64](dimensionsBufferPointer).map(Int.init))
6865
}
6966

70-
/// Returns the data type of the tensor with the provided name stored in the checkpoint.
71-
@inlinable
72-
public func dataType(ofTensorNamed name: String) -> TensorDataType {
73-
TensorDataType(TF_CheckpointReaderGetVariableDataType(handle, name))
67+
/// Returns the scalar type of the tensor with the provided name stored in the checkpoint.
68+
public func scalarTypeOfTensor(named name: String) -> Any.Type {
69+
let dataType = TensorDataType(TF_CheckpointReaderGetVariableDataType(handle, name))
70+
switch dataType._cDataType {
71+
case TF_BOOL: return Bool.self
72+
case TF_INT8: return Int8.self
73+
case TF_UINT8: return UInt8.self
74+
case TF_INT16: return Int16.self
75+
case TF_UINT16: return UInt16.self
76+
case TF_INT32: return Int32.self
77+
case TF_UINT32: return UInt32.self
78+
case TF_INT64: return Int64.self
79+
case TF_UINT64: return UInt64.self
80+
case TF_BFLOAT16: return BFloat16.self
81+
case TF_FLOAT: return Float.self
82+
case TF_DOUBLE: return Double.self
83+
case TF_STRING: return String.self
84+
default: fatalError("Unhandled type: \(dataType)")
85+
}
7486
}
7587

7688
/// Loads and returns the value of the tensor with the provided name stored in the checkpoint.
77-
@inlinable
78-
public func load<Scalar: _TensorFlowDataTypeCompatible>(
79-
tensorNamed name: String
89+
public func loadTensor<Scalar: _TensorFlowDataTypeCompatible>(
90+
named name: String
8091
) -> ShapedArray<Scalar> {
8192
let pointer = TF_CheckpointReaderGetTensor(handle, name, status)
8293
checkOk(status)

0 commit comments

Comments
 (0)