Variance along an axis #440


Closed
LukeMathWalker wants to merge 37 commits

Conversation

LukeMathWalker
Member

Morning!

I crafted a small function to compute the variance along an axis, which seemed to be a missing functionality.
It's my first Rust PR, so I am more than open to suggestions and criticism.

@jturner314
Member

jturner314 commented Apr 24, 2018

Hi @LukeMathWalker. Welcome to Rust, and thanks for the PR! It looks well-written.

There are a few things I'd like to see before this gets merged:

  1. Add a few tests of .var_axis() in tests/array.rs, including a test with complex numbers.
  2. Add a ddof parameter so that the method can calculate both the population variance and sample variance. (For example, see the ddof parameter for numpy.var.)
  3. Reduce the number of heap allocations. (There are two heap allocations in each iteration of the loop. self.subview(axis, i).to_owned() clones the data in the subview. &new_row - &mean performs a heap allocation for its result.)

Here's an example modification for items 2 and 3 (along with a couple of stylistic changes and more docs). (I hope I didn't break anything :-).) It uses azip! to avoid making any heap allocations in the loop. This provides a large performance boost in some cases (e.g. 72 μs → 21 μs on one benchmark).

Is the implementation correct for complex numbers? It's not obvious to me whether or not it is. NumPy handles complex numbers by taking the absolute value before squaring, which seems like a reasonable approach. (I think NumPy's approach follows Wikipedia's statement that "The variance is always a nonnegative real number. It is equal to the sum of the variances of the real and imaginary part of the complex random variable".)

Edit: One more question – is there a source that isn't paywalled for the Welford method?

Edit2: The docs should also indicate what conditions cause the method to panic.
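
For reference, a minimal scalar sketch of the Welford update being discussed (a hypothetical standalone function, not the PR code):

/// One-pass (Welford) variance of a slice; `ddof` is the "delta degrees of freedom".
fn welford_var(xs: &[f64], ddof: f64) -> f64 {
    let mut count = 0.0;
    let mut mean = 0.0;
    let mut sum_sq = 0.0; // running sum of squared deviations from the mean
    for &x in xs {
        count += 1.0;
        let delta = x - mean;
        mean += delta / count;
        sum_sq += delta * (x - mean); // uses the old and the new mean
    }
    sum_sq / (count - ddof)
}

fn main() {
    let v = welford_var(&[1.0, 2.0, 3.0, 4.0], 1.0); // sample variance
    assert!((v - 5.0 / 3.0).abs() < 1e-12);
}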

/// );
/// ```
pub fn var_axis(&self, axis: Axis) -> Array<A, D::Smaller>
where A: LinalgScalar + ScalarOperand,
Member

  1. ScalarOperand is not intended to be used like this, so I'd try to find a different way to do it, without that trait.
  2. Variance usually has a ddof parameter, and I think we need to allow for it in some way for var and std.
  3. We need to find a different way to compute this, to avoid using .to_owned() on every row.

Member Author

What do you mean by "ScalarOperand is not intended to be used like this"?

Member

It's a crutch specifically for some operator impls for arrays. Not all traits have the same role.

@LukeMathWalker
Member Author

LukeMathWalker commented May 1, 2018

The Welford online algorithm does not handle complex numbers by default. Nonetheless, if X=a+ib, where a and b are real vectors, Var[X]=Var[a]+Var[b]; we can thus use this function on the real and on the imaginary parts separately and then sum the results.

What is the best way to do this in Rust?
In Python I would have done something similar to what NumPy does, i.e. a type check:

if issubclass(arr.dtype.type, nt.complexfloating):
    x = um.multiply(x, um.conjugate(x), out=x).real
else:
    x = um.multiply(x, x, out=x)

What is the most idiomatic way to handle it in Rust? Because it looks strange to me to have a type check in a generic function...
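
One illustrative sketch (not from the PR) of applying the Var[X] = Var[a] + Var[b] identity with num-complex, assuming the real-valued var_axis(axis, ddof) that this PR eventually adds:

use ndarray::{Array1, Array2, Axis};
use num_complex::Complex64;

// Hypothetical helper: variance of a complex array as the sum of the variances
// of its real and imaginary parts.
fn complex_var_axis(a: &Array2<Complex64>, axis: Axis, ddof: f64) -> Array1<f64> {
    let re = a.mapv(|z| z.re); // real parts
    let im = a.mapv(|z| z.im); // imaginary parts
    re.var_axis(axis, ddof) + im.var_axis(axis, ddof)
}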

@jturner314
Member

@LukeMathWalker The way to handle this type of thing in Rust is to have all of the types you want to operate on implement the necessary trait bounds. If necessary, you can define your own traits and implement them for external types. (In this case, we need a trait for calculating the complex conjugate of the value (or getting the real and imaginary parts), and it would also be nice to have a trait define the associated real type (e.g. f64 for c64) so that we could always return a real number instead of a complex number from var_axis.)

After thinking about this some more, I would not object to supporting only real numbers, at least until rust-num/num-complex#2 is resolved. In fact, I'd prefer to support only A: Float for the time being. Here's my reasoning:

  • If we want to support complex numbers before rust-num/num-complex#2 is resolved, then we need to add traits to be able to operate over element types in a generic way. In particular, we'd need traits similar to Conjugate and AssociatedReal from the ndarray-linalg crate. I'd prefer to avoid adding these traits to ndarray because I don't see any obvious uses for them other than variance and standard deviation, so the additional complexity seems to outweigh the benefit of supporting complex numbers in this case. (A rough sketch of what such traits could look like follows below.)

  • I'd prefer an A: Float constraint over A: Add + Div + ... (or equivalent bound) for two reasons:

    1. It doesn't seem very useful to be able to calculate the variance of integer arrays since the error will be fairly large.
    2. Since Float is not implemented for complex numbers, this also avoids users trying to take the variance of a complex array and getting an incorrect result.

@bluss What are your thoughts?
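
(A rough sketch of what such element-level traits could look like; the names and signatures below are illustrative, not the actual ndarray-linalg definitions.)

use num_complex::Complex64;

// Illustrative only: a trait tying an element type to its associated real type,
// with |x|^2 as the quantity a complex-aware variance would sum.
trait AssociatedReal {
    type Real;
    fn abs_sq(&self) -> Self::Real; // always real and nonnegative
}

impl AssociatedReal for f64 {
    type Real = f64;
    fn abs_sq(&self) -> f64 {
        self * self
    }
}

impl AssociatedReal for Complex64 {
    type Real = f64;
    fn abs_sq(&self) -> f64 {
        self.norm_sqr() // re^2 + im^2
    }
}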

Other comments on the PR:

  • The most recent commits strip trailing whitespace on some lines unrelated to the changes (e.g. quite a few in src/lib.rs). I prefer not to do this in order to keep a clean history.

  • I noticed that this method will panic if the length of the axis is less than 2. I think we should support axes of length 1 without panicking. What should the behavior be for axes of length 0? Panic or return Err/None? I lean towards "panic".

    To support axes of length 1, we can just change the initialization and iteration to:

    let mut count = A::zero();
    let mut mean = Array::zeros(self.dim.remove_axis(axis));
    let mut sum_sq = Array::zeros(self.dim.remove_axis(axis));
    for subview in self.axis_iter(axis) {

    or we could instead add an if statement that checks for the length = 1 case.
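
    For illustration, a concrete f64 version of this suggestion with a Zip-based loop body (for_each in recent ndarray releases; the method was called apply in 2018-era versions), ddof checks omitted:

    use ndarray::{Array1, Array2, Axis, Zip};

    // Sketch only: Welford update over subviews, with the initialization above,
    // so that axes of length 1 work without panicking.
    fn var_axis0(a: &Array2<f64>, ddof: f64) -> Array1<f64> {
        let mut count = 0.0;
        let mut mean = Array1::<f64>::zeros(a.ncols());
        let mut sum_sq = Array1::<f64>::zeros(a.ncols());
        for subview in a.axis_iter(Axis(0)) {
            count += 1.0;
            Zip::from(&mut mean)
                .and(&mut sum_sq)
                .and(&subview)
                .for_each(|m, s, &x| {
                    let delta = x - *m;
                    *m += delta / count;
                    *s += delta * (x - *m); // no per-iteration heap allocations
                });
        }
        sum_sq.mapv(|s| s / (count - ddof))
    }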

@LukeMathWalker
Member Author

I have refactored var_axis to accept A: Float + ScalarOperand instead of LinalgScalar + ScalarOperand.
I have also changed the initialization, as @jturner314 suggested, to avoid panicking on axes of length 1.
I opted to have the method panic for axes of length 0.

I added another panic trigger for the case where ddof is greater than or equal to the length of the axis (otherwise we might return negative values, which would not make sense for a variance).

I am open to implementing the overall trait architecture to support Complex values, if you believe that to be the best course of action.

For the time being I have not dropped the ScalarOperand bound: the alternative is to instantiate count as a one-element array and rely on broadcasting behaviour, which seems a little more obscure to me than a plain scalar division on an array. I am open to implementing that change as well if @bluss believes it to be really necessary.

danmack and others added 2 commits May 9, 2018 09:39

@jturner314 jturner314 left a comment


I'm sorry about the delay; life has been busy recently.

The Float bound is sufficient; we can remove the ScalarOperand bound by rewriting the division to use mapv. (See the comments below.) Everything else looks good to me.

@LukeMathWalker Will you please make the changes listed in the comments below and squash this PR into a single commit? It would also be nice if you could rebase off of the latest master.

Once that's done, I'll merge this PR unless @bluss has any objections.

panic!("Ddof needs to be strictly smaller than the length \
of the axis you are computing the variance for!")
} else {
sum_sq / (count - ddof)
@jturner314 jturner314 May 23, 2018

We can change this line to sum_sq.mapv(|s| s / (count - ddof)) to remove the ScalarOperand bound.

Edit: It might be faster to do this instead:

let dof = count - ddof;
sum_sq.mapv(|s| s / dof)

to avoid recomputing count - ddof for every element.

/// ```
pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
where
A: Float + ScalarOperand,
Member

The ScalarOperand bound can be removed.

@@ -14,6 +14,7 @@ use imp_prelude::*;
use numeric_util;

use {
ScalarOperand,
Member

This import can be removed.

/// ```
///
/// **Panics** if `ddof` is greater equal than the length of `axis`.
/// **Panics** if `axis` is out of bounds or if lenght of `axis` is zero.
Member

Typo: "lenght" should be "length"

{
let mut count = A::zero();
let mut mean = Array::zeros(self.dim.remove_axis(axis));
let mut sum_sq = Array::zeros(self.dim.remove_axis(axis));
Member

When I tried making the other changes, the compiler had trouble inferring the type of this array, so it was necessary to change this to Array::<A, _>::zeros(...).

@LukeMathWalker
Member Author

I have made all the edits you suggested to remove the ScalarOperand bound.
I have rebased from master too.
To squash all the commits you can select "Squash and merge" from the Pull Request "Merge" button - I don't want to make a mess ^^"

jturner314 added a commit to jturner314/ndarray that referenced this pull request May 28, 2018
@jturner314
Member

Okay, I squashed, merged, and closed this PR.

@LukeMathWalker Thanks for working on this and for your patience! I've wanted to add a variance method to ndarray for a while.
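
For completeness, a quick usage sketch of the merged method, assuming the var_axis(axis, ddof) signature settled on above:

use ndarray::{array, Axis};

fn main() {
    let a = array![[1., 2.], [3., 4.], [5., 6.]];
    // Sample variance (ddof = 1) along axis 0, i.e. over the three rows.
    let v = a.var_axis(Axis(0), 1.);
    assert_eq!(v, array![4., 4.]);
}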

@jturner314 jturner314 closed this May 28, 2018