Rewrite of the code generation script. #26
@rxwei So, if we add the constructor |
That makes sense to me! |
Cool, I'll go ahead and implement that as an experiment to see if it works out. |
@rxwei Actually this makes a lot of sense and it revealed something interesting. The
extension Array : TensorArrayProtocol where Element : TensorGroup {
  public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
    var ptr = address
    for elem in self {
      elem._unpackTensorHandles(into: ptr)
      ptr = ptr!.advanced(by: Int(elem._tensorHandleCount))
    }
  }
  public var _tensorHandleCount: Int32 {
    var count: Int32 = 0
    for elem in self { count += elem._tensorHandleCount }
    return count
  }
  public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
    let size = count / Element._tensorHandleCount
    self = Array((0..<size).map { Element(
      _owning: tensorHandles[$0 * Element._tensorHandleCount])
    })
  }
}
That is, |
Makes sense, though we had to derive |
In this case
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
  precondition(count == _tensorHandleCount)
  self.init(_owning: tensorHandles)
} |
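To make the `init(_owning:count:)` idea concrete, here is a minimal, self-contained sketch. `Handle`, `MockGroup`, `Pair`, and `roundTrip` are stand-ins invented for illustration (the real code works with `CTensorHandle` and `TensorGroup`), so this only shows the shape of the approach: an array is rebuilt from a flat handle buffer by stepping through it in strides of the element's handle count.

```swift
// Stand-in for CTensorHandle; an assumption for illustration only.
typealias Handle = Int

// Hypothetical, simplified analogue of TensorGroup.
protocol MockGroup {
    static var handleCount: Int { get }
    init(owning handles: UnsafePointer<Handle>)
    func unpack(into address: UnsafeMutablePointer<Handle>)
}

// A group made of exactly two handles.
struct Pair: MockGroup, Equatable {
    var first: Handle
    var second: Handle
    static var handleCount: Int { return 2 }

    init(first: Handle, second: Handle) {
        self.first = first
        self.second = second
    }
    init(owning handles: UnsafePointer<Handle>) {
        first = handles[0]
        second = handles[1]
    }
    func unpack(into address: UnsafeMutablePointer<Handle>) {
        address[0] = first
        address[1] = second
    }
}

extension Array where Element: MockGroup {
    // Total number of handles across all elements.
    var handleCount: Int { return count * Element.handleCount }

    // Writes every element's handles into one flat buffer.
    func unpack(into address: UnsafeMutablePointer<Handle>) {
        var ptr = address
        for element in self {
            element.unpack(into: ptr)
            ptr = ptr.advanced(by: Element.handleCount)
        }
    }

    // Rebuilds the array from a flat buffer holding `count` handles.
    init(owning handles: UnsafePointer<Handle>, count: Int) {
        precondition(count % Element.handleCount == 0)
        let size = count / Element.handleCount
        self = (0..<size).map {
            Element(owning: handles.advanced(by: $0 * Element.handleCount))
        }
    }
}

// Packs the pairs into a temporary buffer and reads them back.
func roundTrip(_ pairs: [Pair]) -> [Pair] {
    let total = pairs.handleCount
    let buffer = UnsafeMutablePointer<Handle>.allocate(capacity: total)
    defer { buffer.deallocate() }
    buffer.initialize(repeating: 0, count: total)
    pairs.unpack(into: buffer)
    return [Pair](owning: UnsafePointer(buffer), count: total)
}
```

The explicit `count` parameter is what lets a variable-length container participate: a fixed-size group can check it with a precondition, while an array uses it to decide how many elements to build.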
Also, a random idea about generating boilerplates: Since each of these ops calls the same TFE functions, would it make sense to define a |
Yes that would be great. I was hoping that we could make the current |
|
@rxwei I just pushed a change that implements this. The only changes required to compile this in stdlib are:
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
  precondition(count == _tensorHandleCount)
  self.init(_owning: tensorHandles)
}
extension Array : TensorArrayProtocol where Element : TensorGroup {
  public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
    var ptr = address
    for elem in self {
      elem._unpackTensorHandles(into: ptr)
      ptr = ptr!.advanced(by: Int(elem._tensorHandleCount))
    }
  }
  public var _tensorHandleCount: Int32 {
    var count: Int32 = 0
    for elem in self { count += elem._tensorHandleCount }
    return count
  }
  public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
    let size = count / Element._tensorHandleCount
    self = Array((0..<size).map { Element(
      _owning: tensorHandles[$0 * Element._tensorHandleCount])
    })
  }
}
Regarding the |
Nice! |
|
…handling of input tensor lists for 'eager' mode.
This PR should be ready for review along with swiftlang/swift#24229 . All tests pass locally as the changes are all backwards compatible. |
Could you please fix the merge conflict in |
Done! :) |
@rxwei This should be ready for review as it is backwards compatible and should not break anything in the existing codebase. |
Great. I'd get @pschuh's opinions on this first as I'm less familiar with binding generation. |
Just some minor comments.
return self.op.inferred_counts[number_attr]
if number_attr:
  return self.swift_name + 'Count'
if self.arg_def.type_list_attr:
Can't comment on the raw ops, but these appear to be codegenning as #tfops. This is breaking "saveV2" and "restoreV2" because they no longer return anything. I think the original logic was bad here. They should probably return [AnyTensor] or be otherwise blacklisted.
It depends on the mode you use. I currently set it to tfop-eager-fallback just for backwards compatibility, but the signature should be the same with eager mode. I don't understand though what is wrong with the following two ops. saveV2 does not return anything, as expected, and restoreV2 returns a value with type Dtypes that conforms to TensorGroup. This allows you to save and restore, say, a struct of tensors and avoids the loss of type information incurred by using [AnyTensor]. Maybe I am missing something though.
@inlinable @inline(__always)
public static func saveV2<Dtypes: TensorArrayProtocol>(
  prefix: StringTensor,
  tensorNames: StringTensor,
  shapeAndSlices: StringTensor,
  tensors: Dtypes
) {
  return #tfop("SaveV2",
    prefix,
    tensorNames,
    shapeAndSlices,
    tensors,
    dtypes$dtype: tensors._typeList)
}
@inlinable @inline(__always)
public static func restoreV2<Dtypes: TensorGroup>(
  prefix: StringTensor,
  tensorNames: StringTensor,
  shapeAndSlices: StringTensor
) -> Dtypes {
  let op = TFE_Op("RestoreV2")
  let _ = op.addInput(prefix)
  let _ = op.addInput(tensorNames)
  let _ = op.addInput(shapeAndSlices)
  op.setAttr("dtypes", Dtypes._typeList)
  return op.execute(Int(Dtypes._typeList.count))
}
My bad, it was just restoreV2 that I had a problem with. You're constraining _typeList to be a static value. This is not useful in the plan that I have. I'll fix it later I guess, but I would prefer disabling it.
Actually that was a debate I had. I ended up with somewhat of a middle ground where if the type whose _typeList property we want appears as an output arg only, we constrain it to be a TensorGroup and use a static property. Otherwise, we constrain it to be a TensorArrayProtocol and use an instance property. It’s just that if it’s an output arg in either case we’d need to unpack the tensor handles and that’s something that TensorGroup allows us to do. We could disable it for now but I’m curious what use case it doesn’t work for so we can try and think of a better way to do it.
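The middle ground described above can be sketched with stand-in types. Everything here (`DType`, `MockArrayProtocol`, `MockGroup`, `Checkpoint`, `DynamicList`, `saveTypes`, `restore`) is hypothetical and only mirrors the shape of `TensorArrayProtocol` and `TensorGroup`; it is meant to show why an output-only argument needs a static type list while an input argument can read it from the value.

```swift
// Stand-in for the tensor data type enum; an assumption for illustration.
enum DType { case float, int32, string }

// Hypothetical analogue of TensorArrayProtocol: the type list is read
// from a value, so it may differ between instances.
protocol MockArrayProtocol {
    var typeList: [DType] { get }
}

// Hypothetical analogue of TensorGroup: the type list is also known
// from the type alone, which an output-only argument requires.
protocol MockGroup: MockArrayProtocol {
    static var typeList: [DType] { get }
    init(dtypes: [DType])
}

extension MockGroup {
    // The instance property simply forwards to the static one.
    var typeList: [DType] { return Self.typeList }
}

struct Checkpoint: MockGroup {
    static var typeList: [DType] { return [.float, .int32] }
    init(dtypes: [DType]) {
        precondition(dtypes == Checkpoint.typeList)
    }
}

// A container whose type list is only known at runtime.
struct DynamicList: MockArrayProtocol {
    var typeList: [DType]
}

// Input position: a value exists, so the instance property suffices.
func saveTypes<T: MockArrayProtocol>(_ tensors: T) -> [DType] {
    return tensors.typeList
}

// Output-only position: no value exists yet, so the attribute has to
// come from the static property.
func restore<T: MockGroup>() -> T {
    return T(dtypes: T.typeList)
}
```

In this sketch `DynamicList` can be passed to `saveTypes` but cannot be the result of `restore`, which is the limitation raised in the review for runtime-determined type lists.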
I want to determine the type list at runtime. In this case, I will be serializing and deserializing a dynamic list of tensors.
That sounds like a special use case that's not easy to generalize over the raw ops generation. Given that the current generation script generates somewhat type-safe code, how about we add an untyped overload for restoreV2 which offers the functionality you need?
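One possible shape for such an untyped overload, sketched with stand-in types. `DType`, `ErasedTensor`, `Weights`, and both `restore` functions are hypothetical and do not execute any op; the sketch only shows a runtime-typed entry point living next to a typed one, with the typed one layered on top.

```swift
// Stand-in for the tensor data type enum; an assumption for illustration.
enum DType { case float, int32, string }

// Hypothetical type-erased tensor, standing in for AnyTensor.
struct ErasedTensor {
    let dtype: DType
}

// Hypothetical typed group, standing in for a TensorGroup conformer.
struct Weights {
    static let typeList: [DType] = [.float, .float]
    let tensors: [ErasedTensor]

    init(_ tensors: [ErasedTensor]) {
        precondition(tensors.map { $0.dtype } == Weights.typeList)
        self.tensors = tensors
    }
}

// Untyped overload: the caller supplies the type list at runtime.
func restore(dtypes: [DType]) -> [ErasedTensor] {
    // A real implementation would run the op here; this sketch just
    // produces one placeholder per requested dtype.
    return dtypes.map { ErasedTensor(dtype: $0) }
}

// Typed overload: the type list comes from the result type.
func restore() -> Weights {
    return Weights(restore(dtypes: Weights.typeList))
}
```

The typed path keeps its compile-time checking, and callers that serialize a dynamic list of tensors can use the untyped path directly.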
* Changed 'TensorArrayProtocol' such that it can be used to support output tensor arrays in raw ops.
* Added a '_typeList' property to 'TensorArrayProtocol'.
Friend PR: tensorflow/swift-bindings#26.
@pschuh I made the couple fixes you suggested. Currently trying to build and test swiftlang/swift#24425 with this version of the bindings. |
@pschuh I fixed the bug with the number attributes, but I still had to disable I can also confirm that all tests pass on my machine now. |
I also just removed support for the |
@rxwei @pschuh This is an attempt to rewrite the code generation script so that it supports the following features:
* A mode that you can set to either tfop, eager, or tfop-eager-fallback, which lets you generate ops using either the #tfop operator or the eager mode C API (based on @pschuh's previous implementation). tfop-eager-fallback uses tfop wherever possible and eager otherwise (e.g., for output lists).
* VariantHandle and ResourceHandle are now also supported, allowing us to replace many of the uses of #tfop in stdlib with calls to Raw functions.
* … eager mode.
* … tfop and eager modes.
* … _tffunc is used to trace them.
* … Tensor<T> and StringTensor. In this case, two functions are generated, one for each case.
* … tfop and eager modes share as much as possible and also makes sure that the API remains the same no matter which mode is used (stdlib can be compiled with bindings generated in either mode without any changes).
Only list(func), ref-valued, and complex-valued types are not supported by this script now, but this should be fine as they're not that common. This now covers 1096/1286 ops. Of the remaining 190 ops, 131 are ref-valued, so this covers almost everything now.
This also helps with cleaning up stdlib and transitioning stuff over to swift-apis.
Friend PRs: swiftlang/swift#24261 and tensorflow/swift-apis#109 .
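The mode selection described above can be sketched as a small decision function. `OpSummary`, `Backend`, and `backend(for:mode:)` are hypothetical names, and output lists are the only fallback trigger modeled here because they are the only one named in the description; the actual script may consider other conditions.

```swift
// The three generation modes named in the description.
enum Mode: String {
    case tfop = "tfop"
    case eager = "eager"
    case tfopEagerFallback = "tfop-eager-fallback"
}

// Hypothetical, reduced description of an op for this sketch.
struct OpSummary {
    var name: String
    var hasOutputList: Bool
}

// Which code path the generated function would use.
enum Backend { case tfop, eager }

// Picks the backend for one op under the given mode.
func backend(for op: OpSummary, mode: Mode) -> Backend {
    switch mode {
    case .tfop:
        return .tfop
    case .eager:
        return .eager
    case .tfopEagerFallback:
        // Prefer #tfop, falling back to eager where #tfop cannot be
        // used (modeled here only as "the op has an output list").
        return op.hasOutputList ? .eager : .tfop
    }
}
```

Because only the body of each generated function depends on the chosen backend, the public signature can stay the same across modes, which matches the compatibility goal stated in the description.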