Closed
Labels
P3 (no schedule): We have no plan to work on this and, if it is unassigned, we would be happy to review a PR
Description
The current API is `random.split(key, num)`, which returns an array of shape `(num, 2)`.
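For reference, a quick check of the current behavior. This assumes the legacy `jax.random.PRNGKey` and JAX's default threefry implementation, under which a raw key is a `uint32` array of shape `(2,)`:

```python
import jax

# Legacy raw key: under the default threefry implementation, shape (2,), uint32.
key = jax.random.PRNGKey(0)

# Splitting into `num` keys stacks them along a new leading axis.
keys = jax.random.split(key, 3)
print(keys.shape)  # (3, 2)
```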
I'd like to have an API of the form `random.split(key, shape)`, which returns an array of shape `shape + (2,)` without requiring an additional `reshape`. This would make `random.split` easier to use for preparing multi-dimensional arrays of keys, e.g., for use with repeated `vmap`.
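With the current API, building a grid of keys for repeated `vmap` takes a split followed by a reshape. A minimal sketch, assuming a `(4, 3)` grid, raw keys from the default threefry implementation, and a hypothetical per-key function `sample` that draws one standard-normal value:

```python
import math

import jax

key = jax.random.PRNGKey(0)
shape = (4, 3)

# Current workaround: split prod(shape) keys, then reshape to shape + (2,).
keys = jax.random.split(key, math.prod(shape)).reshape(shape + (2,))


def sample(k):
    # Hypothetical per-key work: one standard-normal scalar from a single key.
    return jax.random.normal(k)


# Each vmap maps over one leading axis, so two nested vmaps consume (4, 3).
samples = jax.vmap(jax.vmap(sample))(keys)
print(samples.shape)  # (4, 3)
```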
Some options:
1. Write a function with a new name for this API, e.g., `random.split_with_shape` (this is a bad name).
2. Overload the `num` argument to allow for either an integer size or a tuple of integers `shape`. This is a little tricky to get right, but in principle could be done (NumPy has a lot of APIs like this).
3. Make `num` optional and add a new optional `shape` parameter that could be used instead.
4. Deprecate `num` and replace it with `shape`. This would involve (2) and an extra temporary keyword-only argument. Eventually, we could require either (a) the new argument name `shape` or (b) the new argument name `shape` and passing a tuple of integers instead of an integer.
I would lean towards option (3). It's a little messy to add an extra argument, but it is cleaner than overloading `random.split` to support either sizes or shapes. The latter would be consistent with how `numpy.random` works with the `size` argument, but not with how the rest of `jax.random` works.
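A similarly hypothetical sketch of option (3): `num` becomes optional and a new `shape` parameter is added. The name `split_with_shape_kwarg` is illustrative only, and making `shape` keyword-only is an assumption of this sketch, not part of the proposal:

```python
import math
from typing import Optional, Sequence

import jax


def split_with_shape_kwarg(key, num: Optional[int] = None, *,
                           shape: Optional[Sequence[int]] = None):
    """Hypothetical option (3): pass at most one of `num` or keyword-only `shape`."""
    if num is not None and shape is not None:
        raise TypeError("pass either `num` or `shape`, not both")
    if shape is None:
        # Existing behavior, including the default of two keys.
        return jax.random.split(key, 2 if num is None else num)
    shape = tuple(shape)
    flat = jax.random.split(key, math.prod(shape))
    # Leading axes follow `shape`; trailing axes are the per-key data, e.g. (2,).
    return flat.reshape(shape + flat.shape[1:])


key = jax.random.PRNGKey(0)
print(split_with_shape_kwarg(key).shape)                # (2, 2)
print(split_with_shape_kwarg(key, 5).shape)             # (5, 2)
print(split_with_shape_kwarg(key, shape=(4, 3)).shape)  # (4, 3, 2)
```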