This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Added support for a 'Tensor.gathering(where:)'. #156

Merged 16 commits on May 31, 2019

eaplatanios

@rxwei How can we make this differentiable? The problem is with the line:

let posAxis = axis < 0 ? axis + rank : axis

but that should not affect differentiability of the function as a whole.

In general, I feel AD should be capable of checking whether certain statements depend on the variables with respect to which we're differentiating, and should not fail in cases where they do not (like this op). Is that something that's planned? It would also avoid using the thunk hack in many places (e.g., with .map calls on arrays for creating shapes).

rxwei commented May 31, 2019

> In general, I feel AD should be capable of checking whether certain statements depend on the variables with respect to which we're differentiating, and should not fail in cases where they do not (like this op). Is that something that's planned? It would also avoid using the thunk hack in many places (e.g., with .map calls on arrays for creating shapes).

Yes, AD has been doing this all along, powered by activity analysis. However, the issue with the ternary operator expression is that AD does not yet support functions with any control flow, even when the control flow is not part of the code that should be differentiated. A ternary operator expression is a control flow construct in Swift.

A workaround is to wrap the control flow in a thunk and call it directly to get a value.

let posAxis = { axis < 0 ? axis + rank : axis }()
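
The thunk pattern itself is plain Swift and can be tried outside a differentiable context. The following is a minimal sketch, assuming a hypothetical free function `normalizedAxis(_:rank:)` that is not part of this PR. It only shows that the immediately invoked closure yields the same value as the bare ternary; it does not demonstrate how AD treats the closure.

```swift
// Hypothetical helper (not from the PR): maps a possibly negative axis
// to its non-negative equivalent for a tensor of the given rank.
func normalizedAxis(_ axis: Int, rank: Int) -> Int {
    // The ternary is wrapped in a closure that is called immediately,
    // so at the source level the function body contains a call
    // rather than a branch.
    let posAxis = { axis < 0 ? axis + rank : axis }()
    return posAxis
}

print(normalizedAxis(-1, rank: 3))  // 2
print(normalizedAxis(1, rank: 3))   // 1
```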

@eaplatanios

I see. That makes sense. Although I believe that the example in #157 does not involve control flow.

Also, the thunk approach does not work in this case. This is what I tried previously. I get this:

note: cannot differentiate through a non-differentiable result; do you want to add '.withoutDerivative()'?
        let posAxis = {axis < 0 ? axis + rank : axis}()
                      ^

rxwei commented May 31, 2019

This is because the thunk captures self, which makes the activity analysis think that the thunk should be differentiated with respect to the captured variable, which we don't support yet. We can create a local variable rank for now.

let rank = rank
let posAxis = { axis < 0 ? axis + rank : axis }()
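
As a rough illustration of the capture issue, here is a minimal sketch in plain Swift, assuming a hypothetical `Shaped` struct as a stand-in for a tensor type. The right-hand side is spelled `self.rank` here (an assumption made so the shadowing is unambiguous) rather than the bare `rank` used in the comment. The point is only that the closure refers to a local `Int` and never mentions `self`.

```swift
// Hypothetical stand-in for a tensor type; only `rank` matters here.
struct Shaped {
    let dimensions: [Int]
    var rank: Int { return dimensions.count }

    func normalizedAxis(_ axis: Int) -> Int {
        // Copy the property into a local constant first, so the closure
        // below captures a plain Int instead of `self`.
        let rank = self.rank
        let posAxis = { axis < 0 ? axis + rank : axis }()
        return posAxis
    }
}

let shaped = Shaped(dimensions: [2, 3, 4])
print(shaped.normalizedAxis(-1))  // 2
```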

@eaplatanios

Oh I see. That makes a lot more sense now. I updated this PR.

@@ -335,6 +335,107 @@ public extension Tensor {
    static func ++ (lhs: Tensor, rhs: Tensor) -> Tensor {
        return lhs.concatenated(with: rhs)
    }

/// Gathers slices of this tensor at `indices` along the `axis` dimension.

Doc comments for non-mutating methods often start with "Returns a ... by <verb>ing ..." instead of "<Verb>s ...". Quite a few doc comments in the library do not follow this guideline yet and should be updated.
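
To make the suggested wording concrete, here is a small sketch on a plain Swift `Array`. The method `gathering(atIndices:)` below is a hypothetical illustration written for this note and is not the `Tensor` method added in this PR; only the shape of the doc comment's first sentence is the point.

```swift
extension Array {
    /// Returns an array by gathering the elements of this array at `indices`.
    ///
    /// - Parameter indices: Positions of the elements to gather, in output order.
    /// - Precondition: Every index must be in `0..<count`.
    func gathering(atIndices indices: [Int]) -> [Element] {
        // Non-mutating: builds a new array and leaves `self` unchanged.
        return indices.map { self[$0] }
    }
}

print(["a", "b", "c"].gathering(atIndices: [2, 0, 2]))  // ["c", "a", "c"]
```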

Sounds good. I'll try to go through them tomorrow.

Filed #160.

rxwei commented May 31, 2019

All tests passed.
