Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

[TF] Removed 'HackyTensorflowMigrationSupport' from apple/swift. #139

Merged
merged 1 commit into from
May 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ xcuserdata
DerivedData/
*.xcodeproj
*~
*.vscode
*.idea

### MacOS ###
.DS_Store
18 changes: 11 additions & 7 deletions Sources/TensorFlow/Core/DataTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@

import CTensorFlow

public extension TensorDataType {
var _cDataType: TF_DataType {
return TF_DataType(rawValue: _internalStorageType)
/// A TensorFlow dynamic type value that can be created from types that conform to
/// `TensorFlowScalar`.
// This simply wraps a `TF_DataType` and allows user code to handle
// `TF_DataType` without importing CTensorFlow, which pollutes the namespace
// with TensorFlow C API declarations.
public struct TensorDataType {
public var _cDataType: TF_DataType

@usableFromInline
internal init(_ cDataType: TF_DataType) {
self._cDataType = cDataType
}

init(_ cDataType: TF_DataType) {
self.init(rawValue: cDataType.rawValue)
}
}

@usableFromInline
Expand Down
39 changes: 39 additions & 0 deletions Sources/TensorFlow/Core/TensorGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,45 @@

import CTensorFlow

/// A protocol representing types that can be mapped to `Array<CTensorHandle>`.
///
/// This protocol is defined separately from `TensorGroup` in order for the number of tensors to be
/// determined at runtime. For example, `[Tensor<Float>]` may have an unknown number of elements at
/// compile time.
///
/// This protocol can be derived automatically for structs whose stored properties all conform to
/// the `TensorGroup` protocol. It cannot be derived automatically for structs whose properties all
/// conform to `TensorArrayProtocol` due to the constructor requirement (i.e., in such cases it
/// would be impossible to know how to break down `count` among the stored properties).
public protocol TensorArrayProtocol {
/// Writes the tensor handles to `address`, which must be allocated with enough capacity to hold
/// `_tensorHandleCount` handles. The tensor handles written to `address` are borrowed: this
/// container still owns them.
func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?)

var _tensorHandleCount: Int32 { get }
var _typeList: [TensorDataType] { get }

init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
}

/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
///
/// When a `TensorGroup` is used as an argument to a tensor operation, it is passed as an argument
/// list whose elements are the tensor fields of the type.
///
/// When a `TensorGroup` is returned as a result of a tensor operation, it is initialized with its
/// tensor fields set to the tensor operation's tensor results.
public protocol TensorGroup: TensorArrayProtocol {

/// The types of the tensor stored properties in this type.
static var _typeList: [TensorDataType] { get }

/// Initializes a value of this type, taking ownership of the `_tensorHandleCount` tensors
/// starting at address `tensorHandles`.
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?)
}

public extension TensorGroup {
/// The number of tensor fields in this type.
static var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) }
Expand Down