WIP: Implementation of Scipy Bootstrap #10871
Conversation
Thanks, this looks like a good start! Mostly small comments, but the larger change I'd consider is to express the resamplings in terms of `vmap` rather than `scan` if possible.
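For concreteness, here is a minimal sketch of what a `vmap`-based resampling could look like; the data, statistic, and function name are illustrative, not taken from the PR:

```python
import jax
import jax.numpy as jnp

# Illustrative inputs, not from the PR.
sample = jnp.arange(100.0)
statistic = jnp.mean
n_resamples = 1000

def resample_once(key):
    # Draw indices with replacement and evaluate the statistic on one resample.
    idxs = jax.random.randint(key, shape=sample.shape, minval=0, maxval=sample.shape[0])
    return statistic(sample[idxs])

# One key per resample; vmap evaluates them in parallel rather than serially.
keys = jax.random.split(jax.random.PRNGKey(0), n_resamples)
theta_hat_b = jax.vmap(resample_once)(keys)
print(theta_hat_b.shape)  # (1000,)
```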
jax/_src/scipy/stats/bootstrap.py
Outdated
```python
from typing import NamedTuple

import scipy.stats as osp_stats
```
Unused import?
jax/_src/scipy/stats/bootstrap.py
Outdated
```python
idxs = random.randint(rng, shape=(n,), minval=0, maxval=n)
# `sample` is a tuple of sample sets; we need to apply the same indexing to each sample set
resample = jax.tree_map(lambda data: data[..., idxs], sample)
next_rng = jax.random.split(rng, 1)[0]
```
This split should happen before `rng` is used to generate random numbers, e.g.

```python
rng, next_rng = jax.random.split(rng)
```

otherwise you may have subtle correlational bugs in the pseudo-random numbers.
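For illustration, a minimal sketch of the split-before-use pattern (the function and variable names here are mine, not the PR's):

```python
import jax

def resample_indices(rng, n):
    # Split first; draw from the fresh subkey and carry the other key forward.
    rng, subkey = jax.random.split(rng)
    idxs = jax.random.randint(subkey, shape=(n,), minval=0, maxval=n)
    return rng, idxs
```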
I have changed to use `vmap`, so now I do the splitting in the `vmap` argument: `jax.random.split(key, n_resamples)`.
jax/_src/scipy/stats/bootstrap.py
Outdated
```python
alpha = jnp.broadcast_to(alpha, shape)
# QUESTION: is it good practice to use vmap here?
# TODO: may need to handle nan
# TODO: handle numeric discrepancy against scipy's _percentile_along_axis
```
`vmap` is fine here, but do you need multiple vmaps if `theta_hat_b` has more than two dimensions?
Should I use control flow here to apply multiple `vmap` transforms if I find that `theta_hat_b` has more than two dimensions?
You can use Python `for` loops to apply multiple `vmap` transforms; something like this:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x.sum()

x = jnp.ones((2, 3, 4, 5))
for i in range(x.ndim - 1):
    f = jax.vmap(f)
print(f(x).shape)
# (2, 3, 4)
```
jax/_src/scipy/stats/bootstrap.py
Outdated
```python
return next_rng, statistic(*resample)

# `xs` is a dummy, used simply for the sake of carrying the loop
_, theta_hat_b = lax.scan(_resample_and_compute_once, rng, jnp.ones(n))
```
Is there a reason you chose `scan` instead of `vmap`? Because it's a serial operation, `scan` is going to be far slower than `vmap` on accelerators.
Nice suggestion! For `_bootstrap_resample_and_compute_statistic`, `vmap` does run faster than `scan`.
jax/_src/scipy/stats/bootstrap.py
Outdated
```python
return idx + 1, statistic(resample)

# TODO: check if it can handle a `statistic` that returns multiple scalars
_, theta_hat_i = lax.scan(_jackknife_resample_and_compute, 0, jnp.ones(n))
```
Same comment here regarding `scan`: we should express this in terms of `vmap` instead to take advantage of parallelism on accelerators.
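As a rough illustration of the `vmap` version, one way to express leave-one-out resampling without a `scan` carry; the modular-indexing trick and names here are my own, not from the PR:

```python
import jax
import jax.numpy as jnp

def jackknife_statistics(sample, statistic):
    n = sample.shape[0]

    def leave_one_out(i):
        # Indices of every element except the i-th, with a static shape (n - 1,).
        idxs = (i + 1 + jnp.arange(n - 1)) % n
        return statistic(sample[idxs])

    # Evaluate all n leave-one-out statistics in parallel.
    return jax.vmap(leave_one_out)(jnp.arange(n))

print(jackknife_statistics(jnp.arange(5.0), jnp.mean))  # leave-one-out means
```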
I found an issue here @jakevdp: for `_jackknife_resample_and_compute_statistic`, I weirdly find that `vmap` runs SLOWER than `scan`. On a CPU Colab with 10,000 samples, `vmap` takes ~210 ms while `scan` takes ~110 ms. Here is the Colab notebook where I did my benchmarking: https://colab.research.google.com/drive/1abKv-zyI-CZ4BEtXc3oh5Ebx_rcFq0dp?usp=sharing

I am not sure if it's because my implementation is inefficient, or I have made some errors.
Interesting! On CPU, `scan` and friends are generally not too bad, but I suspect if you benchmarked on GPU/TPU you'd find `scan` to be much slower than `vmap`.
```python
standard_error: jnp.ndarray


def _bootstrap_resample_and_compute_statistic(sample, statistic, n_resamples, key):
```
This would be more consistent with JAX's typical approach if (1) `key` were the first argument, and (2) batching were handled via `vmap` at the call-site, rather than via an argument passed to the function. So the API would be something like this:

```python
def _bootstrap_resample_and_compute_statistic(key, sample, statistic):
    ...
```

and rather than calling it like

```python
_bootstrap_resample_and_compute_statistic(sample, statistic, n_resamples, key)
```

you could instead call it like

```python
keys = random.split(key, n_resamples)
vmap(_bootstrap_resample_and_compute_statistic, (0, None, None))(keys, sample, statistic)
```

The benefits would be (1) more explicit handling of key splitting by the user of the function, (2) `vmap` at the outer level may be somewhat more efficient (I'm not entirely sure of that, but I think it is the case), and (3) it's more maintainable, because it makes use of JAX's composable transforms in the way they're intended to be used, rather than hiding them behind less flexible batch-aware APIs.
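For what it's worth, a self-contained sketch of this call pattern might look like the following; here I bind `sample` and `statistic` with a closure so that `vmap` only sees the batch of keys, and all names and inputs are illustrative:

```python
import jax
import jax.numpy as jnp

def _bootstrap_resample_and_compute_statistic(key, sample, statistic):
    # Resample with replacement along the leading axis, then apply the statistic.
    n = sample.shape[0]
    idxs = jax.random.randint(key, shape=(n,), minval=0, maxval=n)
    return statistic(sample[idxs])

key = jax.random.PRNGKey(0)
sample = jnp.arange(10.0)
keys = jax.random.split(key, 1000)
# vmap over the batch of keys only; the other arguments are closed over.
theta_hat_b = jax.vmap(
    lambda k: _bootstrap_resample_and_compute_statistic(k, sample, jnp.mean)
)(keys)
print(theta_hat_b.shape)  # (1000,)
```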
```python
miss_first_sample = sample[1:]
miss_last_sample = sample[:-1]

@vmap
```
Same comment here. Can we define `_jackknife_resample_and_compute_statistic` so it natively handles a single batch, and then use `vmap` as appropriate at the call-site?
```python
alpha = jnp.broadcast_to(alpha, shape)
vmap_percentile = jnp.percentile
for i in range(theta_hat_b.ndim - 1):
    vmap_percentile = vmap(vmap_percentile)
```
Rather than vmapping, can we use the `axis` argument to `jnp.percentile`?
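For reference, a sketch of the `axis`-based version; the shapes and values here are illustrative, and note that `jnp.percentile` expects `q` in percent:

```python
import jax.numpy as jnp

theta_hat_b = jnp.ones((3, 4, 1000))  # illustrative batch of bootstrap statistics
alpha = 2.5                           # percentile, in percent
# Reduce over the resample axis directly instead of stacking vmaps.
percentiles = jnp.percentile(theta_hat_b, alpha, axis=-1)
print(percentiles.shape)  # (3, 4)
```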
```python
for i in range(theta_hat_b.ndim - 1):
    vmap_percentile = vmap(vmap_percentile)
percentiles = vmap_percentile(theta_hat_b, alpha)
return percentiles[()]
```
I don't understand the purpose of empty indexing here.
```python
# check alpha is jax array type

if vectorized not in (True, False):
```
Typically APIs don't require object identity for boolean values.
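A minimal sketch of a more conventional validation (the helper name is hypothetical):

```python
def _check_vectorized(vectorized):
    # Check the type rather than membership in (True, False), which
    # compares by equality and would, e.g., also accept 1 and 0.
    if not isinstance(vectorized, bool):
        raise ValueError(f"`vectorized` must be True or False; got {vectorized!r}")
```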
```python
@_wraps(
    scipy.stats.bootstrap,
    lax_description=_replace_random_state_by_key_no_batch_jnp_statistic_doc,
    skip_params=("batch",),
```
You can use `extra_params` here to document the `key` argument.

That said, I'm starting to wonder if this should really be considered a wrapper of `scipy.stats.bootstrap`, because its API is now substantially different. In numpy's case, we don't provide any wrappers for `numpy.random` functionality, instead using a different key-based API in `jax.random`. I'm starting to think that the same treatment would make sense here, because as written `jax.scipy.bootstrap` must be called with a different signature than `scipy.bootstrap`.

It would also solve the issue of how to handle irrelevant params like `vectorized`, and we could write the API in a way that is more typical of JAX library functions (i.e. keep batching orthogonal to the implementation, rather than calling `vmap` within).

What do you think?
@jakevdp Is there a decision/consensus on whether to adhere to the original API?
Hi @wonhyeongseo, a few notes about my PR (you can further discuss with @jakevdp to see if those points are valid to address):

Looking forward to your contribution to the ticket.
@wonhyeongseo Are you still interested in tackling this?

Hello @carlosgmartin! Not at the moment because I don't know how. Would love to see your implementation! 😊

@wonhyeongseo Thanks for letting us know. @riven314 Would you like to continue working on your PR, or would you prefer someone else take over?

@carlosgmartin

Thanks for working on this – I think given the discussion in https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html this would now be considered out-of-scope for JAX itself. Sorry we weren't able to merge your contribution!
address #10375