Added support for a 'Tensor.gathering(where:)'. #156
Co-Authored-By: Richard Wei <[email protected]>
Yes, AD has been doing this all along, powered by activity analysis. However, the issue with the ternary operator expression is that AD does not yet support functions with any control flow, even when the control flow is not part of the code that should be differentiated. A ternary operator expression is a control flow construct in Swift. A workaround is to wrap the control flow in a thunk and call it directly to get a value:
let posAxis = { axis < 0 ? axis + rank : axis }()
I see. That makes sense. Although I believe that the example in #157 does not involve control flow. Also, the thunk approach does not work in this case. This is what I tried previously. I get this:
This is because the thunk captures
let rank = rank
let posAxis = { axis < 0 ? axis + rank : axis }()
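The workaround discussed above can be sketched in plain Swift, without TensorFlow. This is a minimal illustration, not the code from the PR: `Shaped` and `normalizedAxis(_:)` are hypothetical names, and the idea that the local copy helps because the closure would otherwise capture `self` through the `rank` property is an assumption inferred from the suggested fix, not something the thread states outright.

```swift
// Hypothetical stand-in for a tensor-like type; only `rank` matters here.
struct Shaped {
    let shape: [Int]
    var rank: Int { return shape.count }

    /// Returns `axis` mapped into `0..<rank`, assuming `-rank <= axis < rank`.
    func normalizedAxis(_ axis: Int) -> Int {
        // Copy the property into a local constant so the closure below
        // captures a plain `Int` (assumption: instead of capturing `self`).
        // The thread spells this `let rank = rank`; `self.rank` is used here
        // only to make the shadowing explicit.
        let rank = self.rank
        // Immediately invoked closure ("thunk"): the ternary expression lives
        // inside the closure, so this function body contains no branch itself.
        let posAxis = { axis < 0 ? axis + rank : axis }()
        return posAxis
    }
}
```

For a value with shape `[2, 3, 4]`, `normalizedAxis(-1)` yields the last axis and a non-negative axis is returned unchanged.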
Oh I see. That makes a lot more sense now. I updated this PR.
@@ -335,6 +335,107 @@ public extension Tensor {
    static func ++ (lhs: Tensor, rhs: Tensor) -> Tensor {
        return lhs.concatenated(with: rhs)
    }

    /// Gathers slices of this tensor at `indices` along the `axis` dimension.
Doc comments for non-mutating methods often start with `Returns a ... by <verb>ing ...` instead of `<Verb>s ...`. Quite a few doc comments in the library do not follow this guideline yet and should be updated.
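As a rough illustration of the phrasing guideline, here is a sketch on a plain Swift `Array` rather than on `Tensor`. The method `gathering(atIndices:)` below is a hypothetical helper written for this example only; it is not the API added in this PR.

```swift
extension Array {
    // Verb-first phrasing that the review suggests avoiding for
    // non-mutating methods:
    //   /// Gathers elements of this array at `indices`.

    /// Returns an array by gathering the elements of this array at `indices`.
    ///
    /// - Precondition: Every index must be in `0..<count`.
    func gathering(atIndices indices: [Int]) -> [Element] {
        // Look up each requested index in order; repeated indices are allowed.
        return indices.map { self[$0] }
    }
}
```

The "Returns ... by <verb>ing" form signals that the receiver is left unchanged and a new value comes back, which matches the `-ing` method name.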
Sounds good. I'll try to go through them tomorrow.
Filed #160.
All tests passed.
@rxwei How can we make this differentiable? The problem is with the line:
but that should not affect differentiability of the function as a whole.
In general, I feel AD should be able to check whether a statement depends on the variables with respect to which we're differentiating, and not fail when it does not (as in this op). Is that something that's planned? It would also avoid the thunk hack in many places (e.g., with `.map` calls on arrays for creating shapes).