Add `b` and `return_sign` functionality to scipy.special.logsumexp #3488
Conversation
Thanks for tackling this! Looks great. Just a few comments below:
I'm working on the broadcasting today. In the meantime, the tests seem to be failing due to a RuntimeWarning, caused by logsumexp taking the log of a negative value when `return_sign` is false. Since this is just a warning, I don't know why it's causing a test failure.
We treat warnings as errors in tests (because we want a warning-clean build). You'll need to avoid or suppress the warning. The …
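A common way to avoid the warning altogether, rather than suppressing it, is to take the log of the absolute value of the sum and track the sign separately, so `log` never sees a negative argument. A minimal sketch (the helper name is made up, not the PR's actual code):

```python
import jax.numpy as jnp

# Hypothetical helper, for illustration only: log the absolute value
# and record the sign separately, so log never gets a negative input.
def _log_abs_with_sign(sumexp):
    sign = jnp.sign(sumexp)          # -1., 0., or 1. (nan propagates)
    out = jnp.log(jnp.abs(sumexp))   # argument is always >= 0
    return out, sign
```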
Seems like a whole bunch of unrelated tests are failing now. What's up?
Basically every test that uses …
The test failures should be fixed if you rebase your branch on master.
Looks great! There's a stray …
Whoops. Should be fixed now.
Can we merge this?
Thanks! We're not quite ready to merge, unfortunately. I ran this through our internal tests, which exercise the GPU and TPU backends, and we're seeing a number of test failures on GPU/TPU (but not CPU). I'm not entirely certain what may be causing this, but given the random number generators specified in the tests, it's probably related to the handling of nan and inf values.
Ah, I see. To make the behavior consistent with NumPy, I had it return NaN in the case where `return_sign` was false but the sign of the result was negative. I could undo that and just have it return the true value without the sign no matter what, but then the tests would have to be changed when comparing against NumPy.
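In code, the behavior being described is roughly this (a hedged sketch with invented names, not the actual diff):

```python
import jax.numpy as jnp

# Illustration only: mimic scipy/numpy, where return_sign=False and a
# negative weighted sum means log(negative), i.e. nan.
def _finalize(out, sign, return_sign):
    if return_sign:
        return out, sign
    # a negative sum has no real logarithm, so report nan
    return jnp.where(sign < 0, jnp.nan, out)
```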
I've pasted an example failure below; it's in the comparison with numpy's output. The commonality of all the failures is …

I'll have some time later today to help find the root of the issue. If you want to look into it before then, I'd focus on inputs containing some inf and nan values, given the random number generator the test is using in this case (as written, none of the CPU tests cover inputs with nans and infs).

```
[device=GPU] LaxBackedScipyTests.testLogSumExp_shapes=float32[(2, 1, 4),(2, 1, 4)]_axis=0_keepdims=False_return_sign=True_use_b_False
Traceback (most recent call last):
File "<embedded stdlib>/unittest/case.py", line 59, in testPartExecutor
yield
File "<embedded stdlib>/unittest/case.py", line 605, in run
testMethod()
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/absl/testing/parameterized.py", line 282, in bound_param_test
return test_method(self, **testcase_params)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 365, in test_method_wrapper
return test_method(self, *args, **kwargs)
File "<embedded stdlib>/contextlib.py", line 52, in inner
return func(*args, **kwds)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/tests/lax_scipy_test.py", line 142, in testLogSumExp
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 831, in _CheckAgainstNumpy
canonicalize_dtypes=canonicalize_dtypes)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 756, in assertAllClose
rtol=rtol, canonicalize_dtypes=canonicalize_dtypes)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 763, in assertAllClose
self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 730, in assertArraysAllClose
_assert_numpy_allclose(x, y, atol=atol, rtol=rtol)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/jax/test_util.py", line 117, in _assert_numpy_allclose
np.testing.assert_allclose(a, b, **kw)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", line 1501, in assert_allclose
verbose=verbose, header=header, equal_nan=equal_nan)
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", line 757, in assert_array_compare
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
File "/build/work/474fee93c453555e4cdf942d5ea3c8d371f2/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", line 733, in func_assert_same_pos
raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-06, atol=1e-06
x and y nan location mismatch:
x: array([[ 1., nan, nan, 1.]], dtype=float32)
y: array([[1., 1., 1., 1.]], dtype=float32)
```
So if the input contains -inf values, it actually should still give finite values. I'll take a look.
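For reference, scipy already behaves this way, since `exp(-inf)` contributes zero to the sum; a quick check using only the public scipy API:

```python
import numpy as np
import scipy.special as osp_special

x = np.array([0.0, -np.inf, 1.0])
# exp(0) + exp(-inf) + exp(1) == 1 + 0 + e, so the result is finite
print(osp_special.logsumexp(x))  # ~1.3133
```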
Here's a short example of where this implementation differs from scipy:

```python
import numpy as np
import scipy.special as osp_special
import jax.numpy as jnp
from jax.scipy.special import logsumexp  # the version in this branch

np.random.seed(0)
x = np.random.rand(3, 4)
x[x < 0.4] = np.nan

arr1, sign1 = logsumexp(jnp.array(x), axis=0, return_sign=True)
arr2, sign2 = osp_special.logsumexp(x, axis=0, return_sign=True)
print(x)
print(sign1)
print(sign2)
```
Let's try this again...
On Tue, Jun 23, 2020 at 5:38 PM, Jake Vanderplas commented on this pull request, in `jax/scipy/special.py` (#3488 (comment)):
```diff
   dims = _reduction_dims(a, axis)
   dimadd = lambda x: lax.expand_dims(x, dims)
   amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
   amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
   amax_singletons = dimadd(amax)
-  out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
-                _constant_like(a, 0), lax.add, dims)), amax)
+  if b is None:
+    out = lax.add(lax.log(lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
+                  _constant_like(a, 0), lax.add, dims)), amax)
+  sign = lax.stop_gradient(lax.sign(out))
```
Shoot, given the test failures I think I misled you here. We need the sign of `exp(out)`, not the sign of `out`. I think this should probably be something like this: `sign = jnp.where(jnp.isnan(out), np.nan, 1.0)`
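To spell out why `lax.sign(out)` was the wrong quantity (a small numpy illustration, not code from the PR): `out` is the log of the sum, so its sign only indicates whether the sum is above or below 1, while `return_sign` is supposed to report the sign of the sum itself.

```python
import numpy as np

s = 0.5              # a positive sum that happens to be less than 1
out = np.log(s)      # about -0.693
print(np.sign(out))  # -1.0, even though the sum s is positive
```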
Getting closer... there's still an issue when the data contains large negative numbers: your implementation disagrees with scipy here:

```python
import numpy as np
import scipy.special as osp_special
import jax.numpy as jnp
from jax.scipy.special import logsumexp

np.random.seed(0)
x = np.random.rand(3, 4)
x[:, 1] = -np.inf

arr1, sign1 = logsumexp(jnp.array(x), axis=0, return_sign=True)
arr2, sign2 = osp_special.logsumexp(x, axis=0, return_sign=True)
print(sign1)
print(sign2)
```
(Side note: like the previous failures, these come up in the GPU/TPU tests and not the CPU tests, because the current test uses a different random number generator on CPU than on GPU/TPU. I'm not sure why that's the case, but you might look into changing the test to use …)
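One way to get that coverage on every backend would be to patch the special values into the test inputs deterministically instead of relying on the backend's rng. A hedged sketch (the helper is hypothetical, not part of the test suite):

```python
import numpy as np

def _input_with_special_values(shape, seed=0):
    # Hypothetical helper: a random array with nans and -infs placed
    # at fixed positions, so every backend sees the same edge cases.
    rng = np.random.RandomState(seed)
    x = rng.rand(*shape)
    x.flat[::5] = np.nan
    x.flat[1::5] = -np.inf
    return x
```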
Honestly, the tests pass fine when I swap in …
This is probably because a failing case is not generated with the default …
It looks like this does it! Thanks for the contribution!
This resolves Issue #3487.