-
Notifications
You must be signed in to change notification settings - Fork 137
Added support for 'Tensor.batchGathering(atIndices:)'. #157
Conversation
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
Co-Authored-By: Richard Wei <[email protected]>
@rxwei I removed I'd prefer we call it |
Nvm the |
@rxwei I updated this PR so that all tests pass, but there is one issue that came up. I had to use |
|
Right, I also noticed the same problem with the new elementary function support. Thanks for the update! This should be ready to merge then, along with the few other PRs I updated today. I also ran CI tests on them. |
Just got back from a conference. Will get to these PRs tonight! |
Thanks Richard! |
@rxwei this also makes use of
Tensor.withoutDerivative
whereScalar
is not differentiable. How do you think we should avoid these uses here?I also have a more general version of this that requires control flow AD so instead of committing it as non-differentiable but more general function now, I'll leave that one for later.