Implementation of np.corrcoef #1071
There appear to be test failures on CPU. Do you think these are all the same bug as #1072? As far as I can see, that bug only reproduces on GPU, but the NaN disagreement here is clearly happening on CPU, since that's the backend our Travis CI builder is using. Is that the case?
Is there a simple test case for bad np.clip semantics that reproduces on CPU too? (You can try CPU by setting the environment variable CUDA_VISIBLE_DEVICES=0.)
Either way, we might need to temporarily avoid generating NaNs in the test to check this in.
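For reference, here is a minimal sketch of the NumPy behaviour the test compares against. It uses plain NumPy rather than jax.numpy, so it says nothing about which JAX backend is at fault; it only shows that clipping an all-NaN matrix leaves the NaNs in place.

```python
import numpy as onp  # plain NumPy, imported as `onp` as in lax_numpy.py

# An all-NaN matrix, similar in spirit to the failing test input (assumed shape).
x = onp.full((3, 3), onp.nan)

# NumPy's clip propagates NaN: the result is still all NaN,
# not the lower or upper bound.
clipped = onp.clip(x, -1.0, 1.0)
all_nan = bool(onp.isnan(clipped).all())
```

A backend that returns the lower bound for these entries would disagree with this reference result.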
jax/numpy/lax_numpy.py
Outdated
@_wraps(onp.corrcoef)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
  msg = ("jax.numpy.cov not implemented for nontrivial {}. "
Two things:
a) This should probably say "corrcoef", not "cov".
b) Can you please move the message into the if condition? No need to allocate a string in the non-error case. Alternatively, you could just omit the error completely, assuming that all we are missing is the corresponding code in jax.numpy.cov.
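A minimal sketch of suggestion (b), assuming a hypothetical stand-in function `corrcoef_stub` that shows only the argument check and not the real computation: the message is formatted inside the error branch, so nothing is allocated on the normal path.

```python
def corrcoef_stub(x, y=None, rowvar=True, bias=None, ddof=None):
  """Hypothetical stand-in for corrcoef; only the argument check is shown."""
  if y is not None:
    # The message is built only when the error is actually raised.
    raise NotImplementedError(
        "jax.numpy.corrcoef not implemented for nontrivial {}.".format("y"))
  return x  # placeholder for the real computation
```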
jax/numpy/lax_numpy.py
Outdated
  if y is not None: raise NotImplementedError(msg.format('y'))

  c = cov(x, y, rowvar)
  try:
I'd prefer to check the number of dimensions explicitly rather than relying on an exception from np.diag.
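One way the explicit check could look, sketched with plain NumPy. The helper name `normalize_cov` is hypothetical and this is not the code that was merged; it only illustrates branching on the rank of the covariance result instead of catching a ValueError from diag.

```python
import numpy as onp  # plain NumPy, imported as `onp` as in lax_numpy.py

def normalize_cov(c):
  """Hypothetical helper: turn a covariance result into correlations."""
  c = onp.asarray(c)
  if c.ndim == 0:
    # Scalar covariance (a single variable): 1 for a valid variance,
    # NaN for 0, inf or NaN.
    return c / c
  stddev = onp.sqrt(onp.diag(c))
  # Divide row i by stddev[i] and column j by stddev[j].
  return c / stddev[:, None] / stddev[None, :]
```

For example, `normalize_cov([[4.0, 2.0], [2.0, 9.0]])` has ones on the diagonal and 2 / (2 * 3) off the diagonal.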
We found a Contributor License Agreement for you (the sender of this pull request), but were unable to find agreements for all the commit author(s) or Co-authors. If you authored these, maybe you used a different email address in the git commits than was used to sign the CLA (login here to double check)? If these were authored by someone else, then they will need to sign a CLA as well, and confirm that they're okay with these being contributed to Google. ℹ️ Googlers: Go here for more info.
I'm sorry, I messed up something with my local branch. I'll try to fix it. I'm new to this.
Once again, I apologize for the spam of commits. I have changed the code based on the reviews. As for my attempt with np.clip and NaN, it still seems to give an error on CPU.
If it helps, I am using Ubuntu 18.04 and Python 3.6.8, and jax is updated to the latest version.
CLAs look good, thanks! ℹ️ Googlers: Go here for more info.
What version of
which seems like the right thing (and matches numpy).
jax/numpy/lax_numpy.py
Outdated
@_wraps(onp.corrcoef)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
  c = cov(x, y, rowvar)
  if isscalar(c) or (len(c_shape) != 1 and len(c_shape) != 2):
You don't define c_shape, but len(shape(c)) == 0 would seem to cover it.
Note that JAX doesn't distinguish between 0D arrays and scalars.
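A short illustration with plain NumPy (not jax.numpy) of why a rank check covers both cases: isscalar separates a Python scalar from a 0D array, while the shape is the empty tuple for both.

```python
import numpy as onp  # plain NumPy, imported as `onp` as in lax_numpy.py

py_scalar = 3.0
zero_d = onp.array(3.0)

# isscalar is True only for the Python scalar, not for the 0D array.
isscalar_results = (onp.isscalar(py_scalar), onp.isscalar(zero_d))

# shape is () for both, so a single rank test handles either input.
ranks = (len(onp.shape(py_scalar)), len(onp.shape(zero_d)))
```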
I have jaxlib 0.1.22. On my machine, the code using cpu still prints a square matrix of -1s. It seems relevant that the failed Travis case is a mismatch between the two matrices:
although it could be something else.
Oops! I had an important typo before. This
should be
So I think that repro is still a GPU bug. It's likely also that setting
would fix the problem. I'm still wondering what's happening for the CPU case, though; that sounds like a different issue. (In general, NaN semantics aren't heavily tested; for your immediate needs, the right fix is to change the test case not to generate NaNs.)
The NaNs were generated because the standard deviation was 0 for some matrices, and the coefficient is calculated by dividing by the standard deviation. I added an if-case in the test, and it should work now. Thanks!
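A sketch of how such a guard might look, using plain NumPy. The helper name `has_zero_variance` and the sample inputs are hypothetical, not the code from the PR's test; the example only shows that a constant row produces NaNs in the reference result and how a test could detect that input beforehand.

```python
import numpy as onp  # plain NumPy, imported as `onp` as in lax_numpy.py

def has_zero_variance(x, rowvar=True):
  """Hypothetical test helper: True if any variable in `x` is constant."""
  x = onp.atleast_2d(onp.asarray(x, dtype=float))
  axis = 1 if rowvar else 0
  return bool(onp.any(onp.std(x, axis=axis) == 0))

varying = onp.array([[1.0, 2.0, 3.0], [2.0, 1.0, 5.0]])
constant_row = onp.array([[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]])

# The constant row has zero standard deviation, so normalizing the
# covariance divides by zero and the result contains NaNs.
with onp.errstate(divide="ignore", invalid="ignore"):
  r = onp.corrcoef(constant_row)
contains_nan = bool(onp.isnan(r).any())
```

A test could skip or regenerate any input for which `has_zero_variance` returns True. The exact comparison with 0 is an assumption that suits inputs like the ones above; a tolerance may be safer for arbitrary floating-point data.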
Looks great! Thanks for the PR. One more tiny suggestion.
jax/numpy/lax_numpy.py
Outdated
@_wraps(onp.corrcoef)
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None):
  c = cov(x, y, rowvar)
  if isscalar(c) or (len(shape(c)) != 1 and len(shape(c)) != 2):
JAX doesn't distinguish between scalars and 0D arrays. Perhaps there is value in testing for both (e.g., if a numpy classic array is passed here).
Why the two length tests? Why not len(shape(c)) == 0?
Fixed.
Looks great! Thanks for the PR!
As in #70, an implementation of corrcoef.
There is a small issue during testing with the clipping of an all-NaN matrix, which I have described in #1072.