|
| 1 | +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +import CTensorFlow |
| 16 | +import Foundation |
| 17 | + |
| 18 | +/// A TensorFlow checkpoint file reader. |
| 19 | +public class TensorFlowCheckpointReader { |
| 20 | + @usableFromInline internal let status: OpaquePointer |
| 21 | + @usableFromInline internal let handle: OpaquePointer |
| 22 | + |
| 23 | + /// URL of the checkpoint file. |
| 24 | + public let checkpointPath: URL |
| 25 | + |
| 26 | + /// Number of tensors stored in the checkpoint. |
| 27 | + public var tensorCount: Int { Int(TF_CheckpointReaderSize(handle)) } |
| 28 | + |
| 29 | + /// Names of the tensors stored in the checkpoint. |
| 30 | + public var tensorNames: [String] { |
| 31 | + (0..<tensorCount).map { |
| 32 | + String(cString: TF_CheckpointReaderGetVariable(handle, Int32($0))) |
| 33 | + } |
| 34 | + } |
| 35 | + |
| 36 | + /// Creates a new TensorFlow checkpoint reader. |
| 37 | + /// |
| 38 | + /// - Arguments: |
| 39 | + /// - checkpointPath: URL of the checkpoint file. |
| 40 | + @inlinable |
| 41 | + public init?(checkpointPath: URL) { |
| 42 | + self.status = TF_NewStatus() |
| 43 | + self.handle = TF_NewCheckpointReader(checkpointPath.path, status) |
| 44 | + checkOk(status) |
| 45 | + self.checkpointPath = checkpointPath |
| 46 | + } |
| 47 | + |
| 48 | + deinit { |
| 49 | + TF_DeleteCheckpointReader(handle) |
| 50 | + } |
| 51 | + |
| 52 | + /// Returns `true` if the checkpoint contains a tensor with the provided name. |
| 53 | + @inlinable |
| 54 | + public func contains(tensorNamed name: String) -> Bool { |
| 55 | + TF_CheckpointReaderHasTensor(handle, name) > 0 |
| 56 | + } |
| 57 | + |
| 58 | + /// Returns the shape of the tensor with the provided name stored in the checkpoint. |
| 59 | + @inlinable |
| 60 | + public func shape(ofTensorNamed name: String) -> TensorShape { |
| 61 | + let rank = TF_CheckpointReaderGetVariableNumDims(handle, name) |
| 62 | + let dimensions = UnsafeMutablePointer<Int64>.allocate(capacity: Int(rank)) |
| 63 | + defer { dimensions.deallocate() } |
| 64 | + TF_CheckpointReaderGetVariableShape(handle, name, dimensions, rank, status) |
| 65 | + checkOk(status) |
| 66 | + let dimensionsBufferPointer = UnsafeBufferPointer(start: dimensions, count: Int(rank)) |
| 67 | + return TensorShape([Int64](dimensionsBufferPointer).map(Int.init)) |
| 68 | + } |
| 69 | + |
| 70 | + /// Returns the data type of the tensor with the provided name stored in the checkpoint. |
| 71 | + @inlinable |
| 72 | + public func dataType(ofTensorNamed name: String) -> TensorDataType { |
| 73 | + TensorDataType(TF_CheckpointReaderGetVariableDataType(handle, name)) |
| 74 | + } |
| 75 | + |
| 76 | + /// Loads and returns the value of the tensor with the provided name stored in the checkpoint. |
| 77 | + @inlinable |
| 78 | + public func load<Scalar: _TensorFlowDataTypeCompatible>( |
| 79 | + tensorNamed name: String |
| 80 | + ) -> ShapedArray<Scalar> { |
| 81 | + let pointer = TF_CheckpointReaderGetTensor(handle, name, status) |
| 82 | + checkOk(status) |
| 83 | + return ShapedArray<Scalar>(owning: pointer!) |
| 84 | + } |
| 85 | +} |
0 commit comments