
@wyjw (Contributor) commented Jul 28, 2019

Following #70, this PR adds an implementation of corrcoef.

There is a small issue in testing: clipping a matrix that is entirely NaN does not behave as expected. I have described it in #1072.
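For reference, numpy.corrcoef (which this PR mirrors) normalizes each entry of the covariance matrix C by the two corresponding standard deviations, then clips the result to [-1, 1] to absorb rounding error. In exact arithmetic, a variable with zero variance forces its whole row and column to 0/0 (by Cauchy-Schwarz its covariances are also 0), so NaNs appear before the clip is ever applied:

```latex
R_{ij} = \frac{C_{ij}}{\sqrt{C_{ii}\,C_{jj}}},
\qquad
C_{ii} = 0 \;\Longrightarrow\; C_{ij} = 0
\;\text{ and }\; R_{ij} = R_{ji} = \frac{0}{0} = \text{NaN} \quad \text{for all } j.
```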

@hawkinsp (Collaborator) left a comment

There appear to be test failures on CPU. Do you think these are all the same bug as #1072? As far as I can see, that bug only reproduces on GPU, but the NaN disagreement here is clearly happening on CPU, since that's the backend our Travis CI builder uses. Is that the case?

Is there a simple test case for bad np.clip semantics that reproduces on CPU too? (You can try CPU by setting the environment variable CUDA_VISIBLE_DEVICES=0.)

Either way, we might need to temporarily avoid generating NaNs in the test to check this in.


@_wraps(onp.corrcoef)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
  msg = ("jax.numpy.cov not implemented for nontrivial {}. "

Collaborator

Two things:
a) this should probably say "corrcoef" not "cov".
b) can you please move the message into the if condition? No need to allocate a string in the non-error case. Alternatively, you could just omit the error completely, assuming that all we are missing is the corresponding code in jax.numpy.cov.

  if y is not None: raise NotImplementedError(msg.format('y'))

  c = cov(x, y, rowvar)
  try:

Collaborator

I'd prefer to check the number of dimensions explicitly rather than relying on an exception from np.diag.
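An explicit rank check could look like the sketch below. It is an illustration under stated assumptions, not the code that was merged: `corrcoef_sketch` is a hypothetical name, it handles only real inputs, and it ignores `y`, `bias`, and `ddof`. It relies on `jnp.cov` squeezing a single variable down to a 0-D result, as `numpy.cov` does.

```python
import jax.numpy as jnp


def corrcoef_sketch(x, rowvar=True):
  """Hypothetical corrcoef built on jnp.cov (real inputs, no y/bias/ddof)."""
  c = jnp.cov(x, rowvar=rowvar)
  # Branch on the rank explicitly instead of relying on an exception
  # raised by jnp.diag for unexpected shapes.
  if jnp.ndim(c) == 0:
    # Single variable: c / c is 1, or NaN when the variance is 0 (0 / 0).
    return c / c
  if jnp.ndim(c) != 2:
    raise ValueError(f"expected a 0-D or 2-D covariance, got shape {jnp.shape(c)}")
  # Divide row i and column j by the standard deviations of variables i and j.
  stddev = jnp.sqrt(jnp.diag(c))
  c = c / stddev[:, None] / stddev[None, :]
  # Clip rounding error back into [-1, 1], mirroring numpy.corrcoef.
  return jnp.clip(c, -1, 1)
```

With this structure an unsupported shape fails with a message that names the problem, rather than surfacing as whatever error `diag` happens to raise.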

@googlebot (Collaborator)

We found a Contributor License Agreement for you (the sender of this pull request), but were unable to find agreements for all the commit author(s) or Co-authors. If you authored these, maybe you used a different email address in the git commits than was used to sign the CLA (login here to double check)? If these were authored by someone else, then they will need to sign a CLA as well, and confirm that they're okay with these being contributed to Google.
In order to pass this check, please resolve this problem and have the pull request author add another comment and the bot will run again. If the bot doesn't comment, it means it doesn't think anything has changed.


@wyjw (Contributor Author) commented Jul 29, 2019

Sorry, I messed something up in my local branch. I'll try to fix it; I'm new to this.

@wyjw (Contributor Author) commented Jul 29, 2019

Once again, I apologize for the flood of commits. I have changed the code based on the reviews.

As for my attempt with np.clip and NaN, it still seems to give a wrong result on CPU:

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import jax.numpy as np
a = np.full((3,3,), np.nan)
np.clip(a, -1, 1)

If it helps, I am using Ubuntu 18.04 and Python 3.6.8, with jax updated to the latest version.

@googlebot (Collaborator)

CLAs look good, thanks!


@hawkinsp (Collaborator)

What version of jaxlib do you have? For me, your test case prints:

DeviceArray([[nan, nan, nan],
             [nan, nan, nan],
             [nan, nan, nan]], dtype=float32)

which seems like the right thing (and matches numpy).

@_wraps(onp.corrcoef)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
  c = cov(x, y, rowvar)
  if isscalar(c) or (len(c_shape) != 1 and len(c_shape) != 2):

Collaborator

You don't define c_shape, but len(shape(c)) == 0 would seem to cover it.

Note that JAX doesn't distinguish between 0D arrays and scalars.
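The reviewer's point can be checked directly: a Python float, a NumPy scalar, a 0-D NumPy array, and a 0-D JAX array all report an empty shape, so a single `len(jnp.shape(c)) == 0` test covers every case that `isscalar` was meant to catch. A small illustration; the list of inputs is just an example. Note that `jnp.float32(2.0)` produces a 0-D JAX array rather than a separate scalar type.

```python
import numpy as onp
import jax.numpy as jnp

# Different ways a "scalar" covariance could reach corrcoef.
candidates = [2.0, onp.float32(2.0), onp.array(2.0), jnp.array(2.0), jnp.float32(2.0)]

for c in candidates:
  # Every candidate has shape (), so one rank test handles all of them.
  print(type(c).__name__, jnp.shape(c), len(jnp.shape(c)) == 0)
```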

@wyjw (Contributor Author) commented Jul 29, 2019

I have jaxlib 0.1.22. On my machine, the code run on CPU still prints a square matrix of -1s.

It seems relevant that the failed Travis case is a mismatch between the two matrices:


self = <lax_numpy_test.LaxBackedNumpyTests testMethod=testCorrCoef_shape=(10, 5)_dtype=<class 'numpy.bool_'>_rowvar=False_ddof=None_bias=False>
x = array([[nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan],
       [nan, nan, nan, nan, nan]], dtype=float32)
y = array([[nan, nan, nan, nan, -1.],
       [nan, nan, nan, nan, -1.],
       [nan, nan, nan, nan, -1.],
       [nan, nan, nan, nan, -1.],
       [nan, nan, nan, nan, -1.]], dtype=float32)
check_dtypes = True, atol = 0.01, rtol = 0.01

although it could be something else.

@hawkinsp (Collaborator) commented Jul 29, 2019

Oops! I had an important typo before. This

os.environ["CUDA_VISIBLE_DEVICES"]="0"

should be

os.environ["CUDA_VISIBLE_DEVICES"]=""

So I think that repro still reflects a GPU bug. It's also likely that setting

XLA_FLAGS="--xla_gpu_enable_fast_min_max=false"

would fix the problem.

I'm still wondering what's happening in the CPU case, though; that sounds like a different issue. (In general, NaN semantics aren't heavily tested; for your immediate needs, the right fix is to change the test case so it doesn't generate NaNs.)
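With the corrected environment variable, a CPU-only version of the earlier repro might look like the sketch below. The variable must be set before JAX is imported, and the script compares NaN positions against NumPy instead of relying on eyeballing printed output. It assumes a CUDA-enabled machine is where the GPU would otherwise be picked up; on a CPU-only install the setting is simply harmless.

```python
import os

# Hide every GPU *before* importing JAX so it falls back to the CPU backend.
# An empty string disables all GPUs; "0" would instead select GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import numpy as onp
import jax
import jax.numpy as jnp

print(jax.devices())  # shows which backend is actually in use

a = jnp.full((3, 3), jnp.nan)
clipped = jnp.clip(a, -1, 1)

# NumPy's reference result for the same all-NaN input.
expected = onp.clip(onp.full((3, 3), onp.nan, dtype=onp.float32), -1, 1)

print(clipped)
print("NaN positions match NumPy:",
      onp.array_equal(onp.isnan(onp.asarray(clipped)), onp.isnan(expected)))
```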

@wyjw (Contributor Author) commented Jul 29, 2019

The NaNs came from the standard deviation being 0 for some inputs: the correlation coefficient is computed by dividing by the standard deviation, which then gives 0/0. I added a check to the test to skip those inputs, and it should work now.

Thanks!
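A test-side guard along the lines described above might look like the following sketch. `has_constant_variable` is a hypothetical helper, not the check actually added in this PR, and the random boolean input only imitates the shape and dtype of the failing Travis case.

```python
import numpy as onp
import jax.numpy as jnp


def has_constant_variable(x, rowvar=True):
  """Hypothetical helper: True if any variable in `x` has zero variance.

  corrcoef divides by each variable's standard deviation, so such inputs
  yield 0 / 0 = NaN entries and should be skipped (or handled) in tests.
  """
  x = onp.asarray(x, dtype=onp.float64)  # treat bool and int inputs uniformly
  if x.ndim == 1:
    return bool(x.max() == x.min())
  # With rowvar=True each row is a variable, so observations run along axis 1.
  axis = 1 if rowvar else 0
  return bool(onp.any(x.max(axis=axis) == x.min(axis=axis)))


rng = onp.random.RandomState(0)
x = rng.rand(10, 5) > 0.5  # boolean data, shaped like the failing Travis case
if not has_constant_variable(x, rowvar=False):
  onp.testing.assert_allclose(
      onp.asarray(jnp.corrcoef(x.astype(onp.float32), rowvar=False)),
      onp.corrcoef(x, rowvar=False), rtol=1e-2, atol=1e-2)
```

Comparing max against min avoids the rounding that a variance computation can introduce, so a constant variable is detected exactly.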

@hawkinsp (Collaborator) left a comment

Looks great! Thanks for the PR. One more tiny suggestion.

@_wraps(onp.corrcoef)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
  c = cov(x, y, rowvar)
  if isscalar(c) or (len(shape(c)) != 1 and len(shape(c)) != 2):

Collaborator

JAX doesn't distinguish between scalars and 0D arrays. Perhaps there is value in testing for both (e.g., in case a classic NumPy array is passed here).

Why the two length tests? Why not len(shape(c)) == 0?

Contributor Author

Fixed.

@hawkinsp (Collaborator)

Looks great! Thanks for the PR!

@hawkinsp hawkinsp merged commit 96dd2e6 into jax-ml:master Jul 30, 2019