Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 7e19749

Browse files
eaplataniosdan-zheng
authored andcommitted
Added initial support for a TensorFlow checkpoint file reader. (#529)
1 parent 7ee2265 commit 7e19749

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import CTensorFlow
16+
import Foundation
17+
18+
/// A TensorFlow checkpoint file reader.
19+
public class TensorFlowCheckpointReader {
20+
@usableFromInline internal let status: OpaquePointer
21+
@usableFromInline internal let handle: OpaquePointer
22+
23+
/// URL of the checkpoint file.
24+
public let checkpointPath: URL
25+
26+
/// Number of tensors stored in the checkpoint.
27+
public var tensorCount: Int { Int(TF_CheckpointReaderSize(handle)) }
28+
29+
/// Names of the tensors stored in the checkpoint.
30+
public var tensorNames: [String] {
31+
(0..<tensorCount).map {
32+
String(cString: TF_CheckpointReaderGetVariable(handle, Int32($0)))
33+
}
34+
}
35+
36+
/// Creates a new TensorFlow checkpoint reader.
37+
///
38+
/// - Arguments:
39+
/// - checkpointPath: URL of the checkpoint file.
40+
@inlinable
41+
public init?(checkpointPath: URL) {
42+
self.status = TF_NewStatus()
43+
self.handle = TF_NewCheckpointReader(checkpointPath.path, status)
44+
checkOk(status)
45+
self.checkpointPath = checkpointPath
46+
}
47+
48+
deinit {
49+
TF_DeleteCheckpointReader(handle)
50+
}
51+
52+
/// Returns `true` if the checkpoint contains a tensor with the provided name.
53+
@inlinable
54+
public func contains(tensorNamed name: String) -> Bool {
55+
TF_CheckpointReaderHasTensor(handle, name) > 0
56+
}
57+
58+
/// Returns the shape of the tensor with the provided name stored in the checkpoint.
59+
@inlinable
60+
public func shape(ofTensorNamed name: String) -> TensorShape {
61+
let rank = TF_CheckpointReaderGetVariableNumDims(handle, name)
62+
let dimensions = UnsafeMutablePointer<Int64>.allocate(capacity: Int(rank))
63+
defer { dimensions.deallocate() }
64+
TF_CheckpointReaderGetVariableShape(handle, name, dimensions, rank, status)
65+
checkOk(status)
66+
let dimensionsBufferPointer = UnsafeBufferPointer(start: dimensions, count: Int(rank))
67+
return TensorShape([Int64](dimensionsBufferPointer).map(Int.init))
68+
}
69+
70+
/// Returns the data type of the tensor with the provided name stored in the checkpoint.
71+
@inlinable
72+
public func dataType(ofTensorNamed name: String) -> TensorDataType {
73+
TensorDataType(TF_CheckpointReaderGetVariableDataType(handle, name))
74+
}
75+
76+
/// Loads and returns the value of the tensor with the provided name stored in the checkpoint.
77+
@inlinable
78+
public func load<Scalar: _TensorFlowDataTypeCompatible>(
79+
tensorNamed name: String
80+
) -> ShapedArray<Scalar> {
81+
let pointer = TF_CheckpointReaderGetTensor(handle, name, status)
82+
checkOk(status)
83+
return ShapedArray<Scalar>(owning: pointer!)
84+
}
85+
}

0 commit comments

Comments
 (0)