|
14 | 14 |
|
15 | 15 | import CTensorFlow
|
16 | 16 |
|
| 17 | +/// A protocol representing types that can be mapped to `Array<CTensorHandle>`. |
| 18 | +/// |
| 19 | +/// This protocol is defined separately from `TensorGroup` in order for the number of tensors to be |
| 20 | +/// determined at runtime. For example, `[Tensor<Float>]` may have an unknown number of elements at |
| 21 | +/// compile time. |
| 22 | +/// |
| 23 | +/// This protocol can be derived automatically for structs whose stored properties all conform to |
| 24 | +/// the `TensorGroup` protocol. It cannot be derived automatically for structs whose properties all |
| 25 | +/// conform to `TensorArrayProtocol` due to the constructor requirement (i.e., in such cases it |
| 26 | +/// would be impossible to know how to break down `count` among the stored properties). |
| 27 | +public protocol TensorArrayProtocol { |
| 28 | + /// Writes the tensor handles to `address`, which must be allocated with enough capacity to hold |
| 29 | + /// `_tensorHandleCount` handles. The tensor handles written to `address` are borrowed: this |
| 30 | + /// container still owns them. |
| 31 | + func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) |
| 32 | + |
| 33 | + var _tensorHandleCount: Int32 { get } |
| 34 | + var _typeList: [TensorDataType] { get } |
| 35 | + |
| 36 | + init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) |
| 37 | +} |
| 38 | + |
| 39 | +/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`. |
| 40 | +/// |
| 41 | +/// When a `TensorGroup` is used as an argument to a tensor operation, it is passed as an argument |
| 42 | +/// list whose elements are the tensor fields of the type. |
| 43 | +/// |
| 44 | +/// When a `TensorGroup` is returned as a result of a tensor operation, it is initialized with its |
| 45 | +/// tensor fields set to the tensor operation's tensor results. |
| 46 | +public protocol TensorGroup: TensorArrayProtocol { |
| 47 | + |
| 48 | + /// The types of the tensor stored properties in this type. |
| 49 | + static var _typeList: [TensorDataType] { get } |
| 50 | + |
| 51 | + /// Initializes a value of this type, taking ownership of the `_tensorHandleCount` tensors |
| 52 | + /// starting at address `tensorHandles`. |
| 53 | + init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) |
| 54 | +} |
| 55 | + |
17 | 56 | public extension TensorGroup {
|
18 | 57 | /// The number of tensor fields in this type.
|
19 | 58 | static var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) }
|
|
0 commit comments