@jakevdp (Collaborator) commented Jun 14, 2023

Fixes #4013. Why now? This will aid in the transition to custom PRNG (#9263), because it allows split keys to be reshaped without having to account for the intrinsic key dimension.

Suppose you'd like to create six keys, reshaped into batch dimensions of shape (2, 3).

Before custom PRNG, you could write something like this, assuming the default PRNG implementation, whose raw keys have shape (2,):

keys = random.split(key, 6).reshape(2, 3, 2)

With custom PRNG, where a single key has shape (), it must change to this:

keys = random.split(key, 6).reshape(2, 3)

Writing this operation so that it works in both cases requires something like the following, which is somewhat unclear:

keys = random.split(key, 6).reshape(2, 3, *key.shape)
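A minimal sketch of why the `*key.shape` workaround works in both cases, assuming a JAX version that provides both legacy raw keys (`random.PRNGKey`) and typed keys (`random.key`). The variable names are illustrative only; the point is that the identical expression produces a trailing key dimension for the raw key and none for the typed key.

```python
from jax import random

legacy_key = random.PRNGKey(0)  # raw uint32 key; shape (2,) under default threefry
typed_key = random.key(0)       # typed key; a scalar key array with shape ()

# The same expression handles both: *key.shape supplies the trailing
# intrinsic key dimension only when the key actually has one.
legacy_keys = random.split(legacy_key, 6).reshape(2, 3, *legacy_key.shape)
typed_keys = random.split(typed_key, 6).reshape(2, 3, *typed_key.shape)

print(legacy_keys.shape)  # (2, 3, 2) with the default threefry implementation
print(typed_keys.shape)   # (2, 3)
```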

With this PR, we can instead write this and be compatible with both cases:

keys = random.split(key, shape=(2, 3))
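The proposed call can be approximated on top of the existing integer-`num` form of `random.split`. The `split` wrapper below is hypothetical: it is not JAX's `random.split` and not this PR's implementation, just a sketch of the intended semantics, with `shape` made keyword-only and assumed to be a non-empty tuple of positive ints.

```python
import math

from jax import random


def split(key, num=None, *, shape=None):
    """Hypothetical wrapper (not a JAX API) mimicking the proposed signature.

    Pass either `num` (an int, default 2) or the keyword-only `shape`
    (a non-empty tuple of ints giving the batch shape), not both.
    """
    if shape is None:
        return random.split(key, 2 if num is None else num)
    if num is not None:
        raise TypeError("pass either num or shape, not both")
    shape = tuple(shape)
    if not shape:
        raise ValueError("shape must be a non-empty tuple")
    # Split into prod(shape) keys, then reshape only the batch dimensions;
    # *key.shape keeps the trailing key dimension of legacy raw keys and
    # adds nothing for typed keys, whose shape is ().
    keys = random.split(key, math.prod(shape))
    return keys.reshape(*shape, *key.shape)


typed_keys = split(random.key(0), shape=(2, 3))
legacy_keys = split(random.PRNGKey(0), shape=(2, 3))
print(typed_keys.shape)   # (2, 3)
print(legacy_keys.shape)  # (2, 3, 2) with the default threefry implementation
```

Because `shape` is keyword-only in this sketch, a call such as `split(key, 4, shape=(2, 2))` is rejected instead of silently ignoring one of the two arguments.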

Regarding the API chosen here: I considered overloading the num parameter to also accept a tuple of integers, but I judged it cleaner to avoid that kind of polymorphism. I made shape a keyword-only argument for clarity.
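For comparison, recent JAX releases take the overloading route: `jax.random.split` accepts either an int or a tuple of ints for `num`. A minimal sketch, assuming a JAX version recent enough to support a tuple `num`:

```python
from jax import random

# Assumes a JAX version in which `num` may be an int or a tuple of ints.
typed_key = random.key(0)
keys = random.split(typed_key, (2, 3))
print(keys.shape)  # (2, 3)

# The same call on a legacy raw key keeps the trailing key dimension.
legacy_keys = random.split(random.PRNGKey(0), (2, 3))
print(legacy_keys.shape)  # (2, 3, 2) with the default threefry implementation
```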

@NeilGirdhar (Contributor)

This is complete now, right?

@froystig (Member)

Yeah, via #16644

@jakevdp (Collaborator, Author) commented Aug 30, 2023

Yep, thanks!

@jakevdp closed this Aug 30, 2023
@jakevdp deleted the split-shape branch August 30, 2023 22:10

Development

Successfully merging this pull request may close these issues.

random.split() should allow for passing in a desired shape

3 participants