
random.split() should allow for passing in a desired shape #4013

@shoyer

Description


The current API is random.split(key, num), which returns an array of shape (num, 2).

I'd like an API of the form random.split(key, shape) that directly returns an array with shape shape + (2,), without requiring an additional reshape. This would make random.split easier to use for preparing multi-dimensional arrays of keys, e.g., for use with repeated (nested) vmap.
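As a rough illustration of the reshape that is needed today, the sketch below splits one key into a flat batch and then reshapes it into a (3, 4) grid of keys for two nested vmap calls. It assumes the legacy raw-key representation returned by random.PRNGKey, where each key is a uint32 array of shape (2,). The grid size and the sample function are arbitrary choices made for the example.

```python
import jax
from jax import random

key = random.PRNGKey(0)  # legacy raw key: uint32 array of shape (2,)

# Current API: an integer count gives a flat batch of keys.
flat = random.split(key, 3 * 4)  # shape (12, 2)

# Getting a (3, 4) grid of keys needs an extra reshape to (3, 4) + (2,).
# flat.shape[1:] is the per-key trailing shape, (2,) for raw keys.
grid = flat.reshape((3, 4) + flat.shape[1:])  # shape (3, 4, 2)


def sample(k):
    # Draw one standard-normal scalar from a single key.
    return random.normal(k, ())


# The outer vmap maps over axis 0 of grid, the inner one over axis 1.
samples = jax.vmap(jax.vmap(sample))(grid)
print(flat.shape, grid.shape, samples.shape)  # (12, 2) (3, 4, 2) (3, 4)
```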

Some options:

  1. Write a function with a new name for this API, e.g., random.split_with_shape (this is a bad name)
  2. Overload the num argument to accept either an integer size or a tuple-of-integers shape. This is a little tricky to get right, but could be done in principle (NumPy has a lot of APIs like this).
  3. Make num optional and add a new optional shape parameter that could be used instead.
  4. Deprecate num and replace it with shape. This would involve (2) and an extra temporary keyword-only argument. Eventually, we could require either (a) the new argument name shape or (b) the new argument name shape and passing a tuple of integers instead of an integer.
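Option (2) could be prototyped on top of the existing API roughly as follows. split_overloaded is a hypothetical name, not part of jax.random. It treats anything accepted by operator.index as an integer size and anything else as a shape, which is one way (among several) to resolve the ambiguity that makes this option tricky. Like the previous sketch, it assumes legacy raw keys from random.PRNGKey.

```python
import math
import operator

from jax import random


def split_overloaded(key, num):
    """Hypothetical sketch of option (2): num may be an int or a tuple of ints.

    Not part of jax.random; it wraps the existing integer-only random.split.
    """
    try:
        # Integer-like input (including NumPy integers): today's behaviour.
        shape = (operator.index(num),)
    except TypeError:
        # Otherwise treat num as a shape; every entry must be integer-like.
        shape = tuple(operator.index(d) for d in num)
    # Split into prod(shape) keys, then restore the requested leading shape.
    flat = random.split(key, math.prod(shape))
    return flat.reshape(shape + flat.shape[1:])


key = random.PRNGKey(0)
print(split_overloaded(key, 3).shape)       # (3, 2)
print(split_overloaded(key, (2, 3)).shape)  # (2, 3, 2)
```

Because the sketch splits into prod(shape) keys and reshapes, passing an integer yields exactly the same keys as random.split does today.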

I would lean towards option (3). Adding an extra argument is a little messy, but it is cleaner than overloading random.split to accept either sizes or shapes. Overloading would be consistent with how numpy.random handles its size argument, but not with how the rest of jax.random works.
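A comparable sketch of option (3), again as a hypothetical wrapper rather than a change to jax.random itself (split_opt is an illustrative name): num keeps its current meaning, a keyword-only shape argument can be passed instead, and passing both is an error. Falling back to two keys when neither is given mirrors the default of random.split. Legacy raw keys from random.PRNGKey are assumed.

```python
import math

from jax import random


def split_opt(key, num=None, *, shape=None):
    """Hypothetical sketch of option (3); not part of jax.random.

    Give at most one of num (an int) or the keyword-only shape (a tuple of
    ints). With neither, fall back to 2 keys, like random.split's default.
    """
    if num is not None and shape is not None:
        raise TypeError("pass either num or shape, not both")
    if shape is None:
        shape = (2 if num is None else num,)
    shape = tuple(shape)
    # Split into prod(shape) keys, then restore the requested leading shape.
    flat = random.split(key, math.prod(shape))
    return flat.reshape(shape + flat.shape[1:])


key = random.PRNGKey(0)
print(split_opt(key, 4).shape)             # (4, 2)
print(split_opt(key, shape=(2, 3)).shape)  # (2, 3, 2)
```

Making shape keyword-only means existing positional calls such as random.split(key, 4) keep their meaning unchanged.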

Metadata


Labels

P3 (no schedule): We have no plan to work on this and, if it is unassigned, we would be happy to review a PR.
