-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Break TensorGroup into InputTensorGroup and OutputTensorGroup. #20188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,51 +16,72 @@ | |
|
||
import CTensorFlow | ||
|
||
/// A protocol for types that can be used as tensor operation inputs and | ||
/// outputs. When a TensorGroup is used as an input, it gets passed to the | ||
/// tensor operation as an input list whose elements are the tensor fields of | ||
/// the type. When a TensorGroup is used as an output, it gets initialized | ||
/// with its tensor fields set to the tensor operation's tensor outputs. | ||
/// A protocol for types that can be used as tensor operation inputs. When a | ||
/// TensorGroup is used as an input, it gets passed to the tensor operation as | ||
/// an input list whose elements are the tensor fields of the type. | ||
/// | ||
/// TODO: Add a derived conformance to TensorGroup so that users don't have | ||
/// to write the conformance themselves. | ||
public protocol TensorGroup { | ||
/// The types of the tensor stored properties in this type. | ||
static var _typeList: [TensorDataType] { get } | ||
|
||
/// This protocol is divided from OutputTensorGroup in order for the number of | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not valid Swift API doc comment because it's implementation-specific, not something the user is expected to understand. If you wanted it to be seen the library implementer, please use double slashes. |
||
/// tensors to be determined at runtime. For example, Array<Tensor<Float>> may | ||
/// have an unknown number of elements compile time. | ||
public protocol InputTensorGroup { | ||
/// Writes the tensor handles to `address`, which must be allocated | ||
/// with enough capacity to hold `_tensorHandleCount` handles. The tensor | ||
/// handles written to `address` are borrowed: this container still | ||
/// owns them. | ||
func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) | ||
|
||
var _inputTensorHandleCount : Int32 { get } | ||
} | ||
|
||
/// A protocol for types that can be used as tensor operation outputs. When a | ||
/// TensorGroup is used as an output, it gets initialized with its tensor fields | ||
/// set to the tensor operation's tensor outputs. | ||
/// The number of tensors must be known at compile time. | ||
public protocol OutputTensorGroup { | ||
/// The types of the tensor stored properties in this type. | ||
static var _outputTypeList: [TensorDataType] { get } | ||
|
||
/// Initializes a value of this type, taking ownership of the | ||
/// `_tensorHandleCount` tensors that are at `tensorHandles`. | ||
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) | ||
} | ||
|
||
public extension TensorGroup { | ||
public extension OutputTensorGroup { | ||
/// The number of tensor fields in this type. | ||
static var _tensorHandleCount: Int32 { | ||
return Int32(_typeList.count) | ||
static var _outputTensorHandleCount: Int32 { | ||
return Int32(_outputTypeList.count) | ||
} | ||
|
||
/// An array of `nil`s with size equal to `_tensorHandleCount`. The `nil` | ||
/// represents unknown shape. | ||
static var _unknownShapeList: [TensorShape?] { | ||
return Array(repeating: nil, count: Int(_tensorHandleCount)) | ||
return Array(repeating: nil, count: Int(_outputTensorHandleCount)) | ||
} | ||
} | ||
|
||
/// A protocol for types that can be used as tensor operation inputs and | ||
/// outputs. When a TensorGroup is used as an input, it gets passed to the | ||
/// tensor operation as an input list whose elements are the tensor fields of | ||
/// the type. When a TensorGroup is used as an output, it gets initialized | ||
/// with its tensor fields set to the tensor operation's tensor outputs. | ||
/// | ||
/// TODO: Add a derived conformance to TensorGroup so that users don't have | ||
/// to write the conformance themselves. | ||
public protocol TensorGroup : InputTensorGroup & OutputTensorGroup {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not idiomatic. Please change this to a typealias instead. public typealias TensorGroup = InoutTensorGroup & OutputTensorGroup |
||
|
||
//===----------------------------------------------------------------------===// | ||
// Conform standard TensorFlow types to TensorGroup | ||
//===----------------------------------------------------------------------===// | ||
|
||
extension TensorHandle : TensorGroup { | ||
public static var _typeList: [TensorDataType] { | ||
public static var _outputTypeList: [TensorDataType] { | ||
return [Scalar.tensorFlowDataType] | ||
} | ||
|
||
public var _inputTensorHandleCount : Int32 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove the space before |
||
get { return Int32(TensorHandle._outputTypeList.count) } | ||
} | ||
|
||
public func _unpackTensorHandles( | ||
into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: _cTensorHandle) | ||
|
@@ -72,10 +93,14 @@ extension TensorHandle : TensorGroup { | |
} | ||
|
||
extension ResourceHandle : TensorGroup { | ||
public static var _typeList: [TensorDataType] { | ||
public static var _outputTypeList: [TensorDataType] { | ||
return [TensorDataType(TF_RESOURCE)] | ||
} | ||
|
||
public var _inputTensorHandleCount : Int32 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove the space before |
||
get { return Int32(ResourceHandle._outputTypeList.count) } | ||
} | ||
|
||
public func _unpackTensorHandles( | ||
into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: _cTensorHandle) | ||
|
@@ -87,10 +112,14 @@ extension ResourceHandle : TensorGroup { | |
} | ||
|
||
extension VariantHandle : TensorGroup { | ||
public static var _typeList: [TensorDataType] { | ||
public static var _outputTypeList: [TensorDataType] { | ||
return [TensorDataType(TF_VARIANT)] | ||
} | ||
|
||
public var _inputTensorHandleCount : Int32 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove the space before |
||
get { return Int32(VariantHandle._outputTypeList.count) } | ||
} | ||
|
||
public func _unpackTensorHandles( | ||
into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: _cTensorHandle) | ||
|
@@ -102,10 +131,14 @@ extension VariantHandle : TensorGroup { | |
} | ||
|
||
extension Tensor : TensorGroup { | ||
public static var _typeList: [TensorDataType] { | ||
public static var _outputTypeList: [TensorDataType] { | ||
return [Scalar.tensorFlowDataType] | ||
} | ||
|
||
public var _inputTensorHandleCount : Int32 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove the space before |
||
get { return Int32(Tensor._outputTypeList.count) } | ||
} | ||
|
||
public func _unpackTensorHandles( | ||
into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: handle._cTensorHandle) | ||
|
@@ -117,10 +150,14 @@ extension Tensor : TensorGroup { | |
} | ||
|
||
extension TensorElementLiteral : TensorGroup { | ||
public static var _typeList: [TensorDataType] { | ||
public static var _outputTypeList: [TensorDataType] { | ||
return [Scalar.tensorFlowDataType] | ||
} | ||
|
||
public var _inputTensorHandleCount : Int32 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove the space before |
||
get { return Int32(TensorElementLiteral._outputTypeList.count) } | ||
} | ||
|
||
public func _unpackTensorHandles( | ||
into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
address!.initialize(to: handle._cTensorHandle) | ||
|
@@ -130,3 +167,18 @@ extension TensorElementLiteral : TensorGroup { | |
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||
} | ||
} | ||
|
||
extension Array : InputTensorGroup where Element : InputTensorGroup { | ||
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||
var ptr = address | ||
for elem in self { | ||
elem._unpackTensorHandles(into: ptr) | ||
ptr = ptr!.advanced(by: Int(elem._inputTensorHandleCount)) | ||
} | ||
} | ||
public var _inputTensorHandleCount : Int32 { get { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please reformat this as the following: public var _inputTensorHandleCount : Int32 {
get {
...
}
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be even better to do it without the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes it definitely would. I didn't notice there's no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove space before |
||
var count: Int32 = 0 | ||
for elem in self { count += elem._inputTensorHandleCount } | ||
return count | ||
} } | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"tensor operation inputs" in an implementation-specific comment. It does not have a meaning to the user.