
Conversation

@dpfau (Contributor) commented Jun 18, 2020

This resolves Issue #3487
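
For orientation, here is a hedged sketch of the functionality this PR adds, mirroring scipy.special.logsumexp's b (per-element scaling) and return_sign options; the input values are made up for illustration:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([0.0, 1.0, 2.0])
b = jnp.array([1.0, -1.0, 1.0])  # scale factors; may be negative

# log(sum(b * exp(x))) computed stably; with return_sign=True the first
# output is log|sum| and the second is the sign of the sum.
val, sign = logsumexp(x, b=b, return_sign=True)
print(val, sign)  # approximately 1.7353, 1.0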

@dpfau changed the title to "Add b and return_sign functionality to scipy.special.logsumexp" Jun 18, 2020
@jakevdp self-assigned this Jun 18, 2020

@jakevdp (Collaborator) left a comment

Thanks for tackling this! Looks great. Just a few comments below:

@dpfau (Contributor, Author) commented Jun 19, 2020

I'm working on the broadcasting today. In the meantime, the tests seem to be failing due to a RuntimeWarning, caused by logsumexp taking the log of a negative value when return_sign is false. Since this is just a warning, I don't know why it's causing a test failure.

@hawkinsp (Collaborator) commented

We treat warnings as errors in tests (because we want a warning-clean build).

You'll need to avoid or suppress the warning. The jtu.ignore_warning decorator may help.
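
A minimal sketch of the underlying pattern using only the standard warnings module (my assumption is that jtu.ignore_warning applies a similar filter as a decorator); the message below is the one NumPy emits when log hits a negative value:

import warnings
import numpy as np

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning,
                            message="invalid value encountered")
    result = np.log(-1.0)  # would otherwise raise a RuntimeWarning; yields nan
print(result)  # nan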

@dpfau (Contributor, Author) commented Jun 21, 2020

Seems like a whole bunch of unrelated tests are failing now. What's up?

@dpfau (Contributor, Author) commented Jun 21, 2020

Basically every test that uses tostring in TF is telling me to use tobytes instead, even though it worked fine 2 days ago.
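
(For context, this appears to be the NumPy 1.19 deprecation of ndarray.tostring in favor of the identically-behaved tobytes, which landed right around this time; a quick check, where the exact bytes depend on platform endianness:

import numpy as np

a = np.arange(3, dtype=np.int32)
print(a.tobytes())  # same bytes tostring() returned, without the DeprecationWarning
)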

@dpfau requested a review from jakevdp June 21, 2020 18:05
@jakevdp (Collaborator) commented Jun 22, 2020

The test failures should be fixed if you rebase your branch on master.

@dpfau requested a review from jakevdp June 22, 2020 21:04
@jakevdp (Collaborator) commented Jun 22, 2020

Looks great! There's a stray pdb.set_trace() presumably left over from debugging; if you remove that we can get this merged 😁

@dpfau (Contributor, Author) commented Jun 22, 2020

Whoops. Should be fixed now.

@dpfau (Contributor, Author) commented Jun 23, 2020

Can we merge this?

@jakevdp (Collaborator) commented Jun 23, 2020

Thanks! We're not quite ready to merge, unfortunately. I ran this through our internal tests, which exercise the GPU and TPU backends, and we're seeing a number of test failures on GPU/TPU (but not CPU) in cases where return_sign=True and use_b=False.

I'm not entirely certain what may be causing this, but given the random number generators specified in the tests, it's probably related to handling of nan and inf values in the inputs.

@dpfau (Contributor, Author) commented Jun 23, 2020

Ah I see. To make the behavior consistent with NumPy, I had it return NaN in the case where return_sign was false but the sign of the result was negative. I could undo that and just have it return the true value without the sign no matter what, but then the tests would have to be changed when comparing against NumPy.
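
For reference, a small illustration of the scipy behavior being matched here; with a negative scale b the sum can be negative, and return_sign=False then yields nan (along with a RuntimeWarning):

import numpy as np
import scipy.special as osp_special

x = np.array([0.0, 1.0])
b = np.array([1.0, -2.0])  # makes sum(b * exp(x)) negative

print(osp_special.logsumexp(x, b=b))                    # nan, plus a warning
print(osp_special.logsumexp(x, b=b, return_sign=True))  # (log|sum|, -1.0)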

@jakevdp (Collaborator) commented Jun 23, 2020

I've pasted an example failure below; it's in the comparison with numpy's output. The commonality of all the failures is:

  1. GPU or TPU backend
  2. return_sign=True and use_b=False

I'll have some time later today to help find the root of the issue. If you want to look into it before then, I'd focus on inputs containing some inf and nan values, given the random number generator the test is using in this case (as written, none of the CPU tests are covering inputs with nans and infs).

[device=GPU] LaxBackedScipyTests.testLogSumExp_shapes=float32[(2, 1, 4),(2, 1, 4)]_axis=0_keepdims=False_return_sign=True_use_b_False

Traceback (most recent call last):
  File "<embedded stdlib>/unittest/case.py", line 59, in testPartExecutor
    yield
  File "<embedded stdlib>/unittest/case.py", line 605, in run
    testMethod()
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/absl/testing/parameterized.py", line 282, in bound_param_test
    return test_method(self, **testcase_params)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 365, in test_method_wrapper
    return test_method(self, *args, **kwargs)
  File "<embedded stdlib>/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/tests/lax_scipy_test.py", line 142, in testLogSumExp
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 831, in _CheckAgainstNumpy
    canonicalize_dtypes=canonicalize_dtypes)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 756, in assertAllClose
    rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 763, in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 730, in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 117, in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", line 1501, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", line 757, in assert_array_compare
    flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
  File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", line 733, in func_assert_same_pos
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-06, atol=1e-06

x and y nan location mismatch:
 x: array([[ 1., nan, nan,  1.]], dtype=float32)
 y: array([[1., 1., 1., 1.]], dtype=float32)

@dpfau (Contributor, Author) commented Jun 23, 2020

So if the input contains -inf values, it actually should still give finite values. I'll take a look.
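
A quick check of that expectation, assuming scipy's conventions: -inf entries contribute exp(-inf) = 0 to the sum, so the result stays finite whenever at least one entry is finite, and an all--inf reduction gives (-inf, 0):

import numpy as np
import scipy.special as osp_special

print(osp_special.logsumexp(np.array([-np.inf, 0.0])))  # 0.0
print(osp_special.logsumexp(np.array([-np.inf, -np.inf]),
                            return_sign=True))          # (-inf, 0.0)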

@jakevdp (Collaborator) commented Jun 23, 2020

Here's a short example of where this implementation differs from scipy:

import numpy as np
import jax.numpy as jnp
import scipy.special as osp_special
from jax.scipy.special import logsumexp  # in this branch

np.random.seed(0)
x = np.random.rand(3, 4)
x[x < 0.4] = np.nan

arr1, sign1 = logsumexp(jnp.array(x), axis=0, return_sign=True)
arr2, sign2 = osp_special.logsumexp(x, axis=0, return_sign=True)

print(x)
print(sign1)
print(sign2)

Output:

[[0.5488135  0.71518937 0.60276338 0.54488318]
 [0.4236548  0.64589411 0.43758721 0.891773  ]
 [0.96366276        nan 0.79172504 0.52889492]]
[1. 1. 1. 1.]
[ 1. nan  1.  1.]

@dpfau (Contributor, Author) commented Jun 23, 2020 via email

@jakevdp (Collaborator) commented Jun 23, 2020

Getting closer... there's still an issue when the data contains large negative numbers: your implementation returns sign=1 where numpy returns sign=0. For example:

import numpy as np
import jax.numpy as jnp
import scipy.special as osp_special
from jax.scipy.special import logsumexp

np.random.seed(0)
x = np.random.rand(3, 4)
x[:, 1] = -np.inf

arr1, sign1 = logsumexp(jnp.array(x), axis=0, return_sign=True)
arr2, sign2 = osp_special.logsumexp(x, axis=0, return_sign=True)

print(sign1)
print(sign2)

Output:

[1. 1. 1. 1.]
[1. 0. 1. 1.]

(Side-note: like the previous failures, these come up in the GPU/TPU tests and not the CPU tests, because the current test uses a different random number generator on CPU than on GPU/TPU. I'm not sure why that's the case, but you might look into changing the test to use rand_some_inf_and_nan for CPU tests as well in order to catch these errors more quickly).

@dpfau (Contributor, Author) commented Jun 23, 2020

Honestly, the tests pass fine when I swap in rand_some_inf_and_nan. I'll keep trying.

@jakevdp (Collaborator) commented Jun 23, 2020

> Honestly, the tests pass fine when I swap in rand_some_inf_and_nan. I'll keep trying.

This is probably because a failing case is not generated with the default num_generated_cases=10. On GitHub CI the value is 25, and on borg it's 100, I believe, so many more inputs are tested there. You can pass, e.g., --num_generated_cases=50 locally to test more cases (see https://jax.readthedocs.io/en/latest/developer.html#running-the-tests for details).

@jakevdp (Collaborator) commented Jun 23, 2020

It looks like this does it! Thanks for the contribution!

@jakevdp merged commit 9d173c6 into jax-ml:master Jun 23, 2020