Conversation

riven314

Addresses #10375.

@riven314 changed the title from "Implementation of Scipy Bootstrap" to "WIP: Implementation of Scipy Bootstrap" on May 28, 2022
@jakevdp (Collaborator) left a comment

Thanks, this looks like a good start! Mostly small comments, but the larger change I'd consider is to express the resamplings in terms of vmap rather than scan if possible.


from typing import NamedTuple

import scipy.stats as osp_stats
Collaborator

Unused import?

idxs = random.randint(rng, shape=(n,), minval=0, maxval=n)
# `sample` is a tuple of sample sets, so we need to apply the same indexing to each sample set
resample = jax.tree_map(lambda data: data[..., idxs], sample)
next_rng = jax.random.split(rng, 1)[0]
Collaborator

This split should happen before rng is used to generate random numbers, e.g.

rng, next_rng = jax.random.split(rng)

otherwise you may have subtle correlational bugs in the pseudo-random numbers.
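A minimal sketch of this split-before-use pattern (the draw_indices helper is hypothetical, for illustration only):

```python
import jax

def draw_indices(rng, n):
  # Split first, then sample with the fresh subkey; the carried-forward
  # key is never itself consumed by a sampling primitive.
  rng, subkey = jax.random.split(rng)
  idxs = jax.random.randint(subkey, shape=(n,), minval=0, maxval=n)
  return rng, idxs

rng = jax.random.PRNGKey(0)
rng, idxs1 = draw_indices(rng, 5)
rng, idxs2 = draw_indices(rng, 5)
```

Because each draw consumes a distinct subkey, consecutive calls produce uncorrelated streams.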

@riven314 (Author) commented on Jun 2, 2022

Since I have changed to vmap, I now do the splitting in the vmap argument: jax.random.split(key, n_resamples)

alpha = jnp.broadcast_to(alpha, shape)
# QUESTION: is it good practice to use vmap here?
# TODO: may need to handle nan
# TODO: handle numeric discrepancy against scipy's _percentile_along_axis
Collaborator

vmap is fine here, but do you need multiple vmaps if theta_hat_b has more than two dimensions?

Author

Should I use control flow here to apply multiple vmaps if I find that theta_hat_b has more than 2 dimensions?

Collaborator

You can use Python for loops to apply multiple vmap transforms; something like this:

import jax
import jax.numpy as jnp

def f(x):
  return x.sum()

x = jnp.ones((2, 3, 4, 5))
for i in range(x.ndim - 1):
  f = jax.vmap(f)
print(f(x).shape)
# (2, 3, 4)

return next_rng, statistic(*resample)

# xs is a dummy array, used only to set the number of loop iterations
_, theta_hat_b = lax.scan(_resample_and_compute_once, rng, jnp.ones(n))
Collaborator

Is there a reason you chose scan instead of vmap? Because it's a serial operation, scan is going to be far slower than vmap on accelerators.
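For a concrete comparison, here is a hedged sketch (jnp.mean stands in for statistic, and the helper name is hypothetical) computing the same resampled statistics serially with scan and in parallel with vmap:

```python
import jax
import jax.numpy as jnp
from jax import lax

def stat_of_resample(key, sample):
  # Draw one bootstrap resample (with replacement) and reduce it.
  idxs = jax.random.randint(key, shape=(sample.size,),
                            minval=0, maxval=sample.size)
  return sample[idxs].mean()

sample = jnp.arange(100.0)
keys = jax.random.split(jax.random.PRNGKey(0), 1000)

# Serial: one resample per scan step.
_, serial = lax.scan(lambda carry, k: (carry, stat_of_resample(k, sample)),
                     None, keys)

# Parallel: all resamples at once; the same keys give the same results.
parallel = jax.vmap(stat_of_resample, in_axes=(0, None))(keys, sample)
```

Both produce identical values; the vmap version exposes the parallelism to XLA instead of forcing 1000 sequential steps.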

@riven314 (Author) commented on Jun 2, 2022

Nice suggestion!
For _bootstrap_resample_and_compute_statistic, vmap does indeed run faster than scan.

return idx + 1, statistic(resample)

# TODO: check whether it can handle a `statistic` that returns multiple scalars
_, theta_hat_i = lax.scan(_jackknife_resample_and_compute, 0, jnp.ones(n))
Collaborator

Same comment here regarding scan - we should express this in terms of vmap instead to take advantage of parallelism on accelerators.
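A hedged sketch of the jackknife expressed with vmap (the leave-one-out index trick and the mean statistic are illustrative assumptions, not the PR's actual code). The trick keeps shapes static, which vmap requires:

```python
import jax
import jax.numpy as jnp

def jackknife_statistic(i, sample):
  # Leave-one-out indices with a static shape: take n-1 positions and
  # shift those at or past i up by one, so position i is skipped.
  idxs = jnp.arange(sample.size - 1)
  idxs = jnp.where(idxs < i, idxs, idxs + 1)
  return sample[idxs].mean()

sample = jnp.arange(10.0)
# All n leave-one-out statistics, evaluated in parallel.
theta_hat_i = jax.vmap(jackknife_statistic, in_axes=(0, None))(
    jnp.arange(sample.size), sample)
```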

@riven314 (Author) commented on Jun 2, 2022

I found an issue here @jakevdp

For _jackknife_resample_and_compute_statistic, I oddly find that vmap runs SLOWER than scan.
On a CPU Colab with 10,000 samples, vmap takes ~210 ms while scan takes ~110 ms.

Here is the Colab notebook that I did my benchmarking:
https://colab.research.google.com/drive/1abKv-zyI-CZ4BEtXc3oh5Ebx_rcFq0dp?usp=sharing

I am not sure if it's because my implementation is inefficient, or if I have made some errors.

Collaborator

Interesting! On CPU, scan and friends are generally not too bad, but I suspect that if you benchmarked on GPU/TPU you'd find scan to be much slower than vmap.

standard_error: jnp.ndarray


def _bootstrap_resample_and_compute_statistic(sample, statistic, n_resamples, key):
Collaborator

This would be more consistent with JAX's typical approach if (1) key were the first argument, and (2) batching were handled via vmap at the call-site, rather than via an argument passed to the function. So the API would be something like this:

def _bootstrap_resample_and_compute_statistic(key, sample, statistic):
  ...

and rather than calling it like

_bootstrap_resample_and_compute_statistic(sample, statistic, n_resamples, key)

you could instead call it like

keys = random.split(key, n_resamples)
vmap(_bootstrap_resample_and_compute_statistic, (0, None, None))(keys, sample, statistic)

The benefits would be (1) more explicit handling of key splitting by the user of the function, (2) vmap at the outer levels may be somewhat more efficient (I'm not entirely sure of that, but I think it is the case), and (3) it's more maintainable, because it uses JAX's composable transforms in the way they're intended to be used, rather than hiding them behind less flexible batch-aware APIs.
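A self-contained sketch of this key-first design (function and variable names are hypothetical, and jnp.mean stands in for statistic). Here the non-array statistic is closed over with functools.partial to keep it out of vmap's axes, and only the keys are batched:

```python
import functools
import jax
import jax.numpy as jnp

def bootstrap_resample_statistic(key, sample, statistic):
  # Handles a single resample; the key comes first and batching is
  # left entirely to vmap at the call-site.
  n = jax.tree_util.tree_leaves(sample)[0].shape[-1]
  idxs = jax.random.randint(key, shape=(n,), minval=0, maxval=n)
  resample = jax.tree_util.tree_map(lambda data: data[..., idxs], sample)
  return statistic(*resample)

key = jax.random.PRNGKey(0)
sample = (jnp.arange(10.0),)
keys = jax.random.split(key, 500)

# Bind the non-array arguments, then batch over the keys only.
fn = functools.partial(bootstrap_resample_statistic,
                       sample=sample, statistic=jnp.mean)
theta_hat_b = jax.vmap(fn)(keys)
```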

miss_first_sample = sample[1:]
miss_last_sample = sample[:-1]

@vmap
Collaborator

Same comment here. Can we define _jackknife_resample_and_compute_statistic so it natively handles a single batch, and then use vmap as appropriate at the call-site?

alpha = jnp.broadcast_to(alpha, shape)
vmap_percentile = jnp.percentile
for i in range(theta_hat_b.ndim - 1):
  vmap_percentile = vmap(vmap_percentile)
Collaborator

Rather than vmapping, can we use the axis argument to jnp.percentile?
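A quick sketch of the equivalence (shapes and data are illustrative): nested vmaps over every leading axis reduce over the trailing axis, which is exactly what a single axis-aware call expresses.

```python
import jax
import jax.numpy as jnp

theta_hat_b = jnp.arange(24.0).reshape(2, 3, 4)

# Nested vmaps, reducing over the trailing axis...
f = jnp.percentile
for _ in range(theta_hat_b.ndim - 1):
  f = jax.vmap(f, in_axes=(0, None))
via_vmap = f(theta_hat_b, 50.0)

# ...compute the same thing as one axis-aware call:
via_axis = jnp.percentile(theta_hat_b, 50.0, axis=-1)
```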

for i in range(theta_hat_b.ndim - 1):
  vmap_percentile = vmap(vmap_percentile)
percentiles = vmap_percentile(theta_hat_b, alpha)
return percentiles[()]
Collaborator

I don't understand the purpose of empty indexing here.
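For context, one hedged reading of the empty index: in NumPy, arr[()] unwraps a 0-d array into a scalar, while in JAX it returns a 0-d array unchanged, so here it likely has no effect. A small check:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.array(1.5)
x_jnp = jnp.array(1.5)

# NumPy: empty-tuple indexing of a 0-d array yields a scalar...
np_result = x_np[()]
# ...whereas JAX returns a 0-d array, so the indexing is a no-op.
jnp_result = x_jnp[()]
```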

# check alpha is jax array type


if vectorized not in (True, False):
Collaborator

Typically APIs don't require object identity for boolean values.

@_wraps(
scipy.stats.bootstrap,
lax_description=_replace_random_state_by_key_no_batch_jnp_statistic_doc,
skip_params=("batch",),
@jakevdp (Collaborator) commented on Jun 21, 2022

You can use extra_params here to document the key argument.

That said, I'm starting to wonder if this should really be considered a wrapper of scipy.stats.bootstrap, because its API is now substantially different. In numpy's case, we don't provide any wrappers for numpy.random functionality, instead using a different key-based API in jax.random. I'm starting to think that the same treatment would make sense here, because as written jax.scipy.bootstrap must be called with a different signature than scipy.bootstrap.

It also would solve the issue of how to handle irrelevant params like vectorized, and we could write the API in a way that is more typical of JAX library functions (i.e. keep batching orthogonal to the implementation, rather than calling vmap within).

What do you think?

Contributor

@jakevdp Is there a decision/consensus on whether to adhere to the original API?

@wonhyeongseo (Contributor)

Hello, @riven314!
I'm interested in this PR, may I please address @jakevdp's reviews?
Thank you, and I hope you have a wonderful day!

@riven314 (Author) commented on Aug 23, 2022

Hi @wonhyeongseo
Thanks for your interest! Currently I don't have bandwidth for this ticket, at least for the next 2 months, so feel free to work on it! Otherwise, I can continue working on it whenever I regain bandwidth.

A few notes about my PR (you can further discuss with @jakevdp to see if those points are valid to address):

  1. The numerical discrepancy between my JAX implementation and scipy's is NOT immaterial, even after ramping up the number of resamples. I can't give you precise statistics (a discrepancy of ~0.01?) because I coded it a few months ago, but such a discrepancy would probably make the unit tests fail.
  2. My JAX implementation is a bit slower than scipy's in some settings (CPU).
  3. Scipy bootstrap supports a lot of different scenarios (e.g. multiple statistics, one-set/multi-set/paired samples, different bootstrap methods). For now, I am not confident that my implementation works as expected across all of them.

Look forward to your contribution to the ticket.

@carlosgmartin (Contributor)

@wonhyeongseo Are you still interested in tackling this?

@wonhyeongseo (Contributor) commented on Feb 3, 2023

Hello @carlosgmartin ! Not at the moment because I don't know how. Would love to see your implementation! 😊

@carlosgmartin (Contributor)

@wonhyeongseo Thanks for letting us know.

@riven314 Would you like to continue working on your PR, or would you prefer someone else take over?

@riven314 (Author) commented on Feb 3, 2023

@carlosgmartin
Hi! I don't foresee having bandwidth in the short run, so I would appreciate it if someone is willing to take over.
I am glad to explain my code if any help is needed!

@jakevdp (Collaborator) commented on Nov 3, 2023

Thanks for working on this – I think given the discussion in https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html this would now be considered out-of-scope for JAX itself. Sorry we weren't able to merge your contribution!

@jakevdp closed this on Nov 3, 2023