When we call jax.random.shuffle(key, array, axis=0) on an ndarray of shape (n, m), I'm expecting (based on the docs) that we'll end up with a new array where every row keeps the same internal order but the order of the rows is shuffled. I'm not sure this is actually guaranteed by the docs, which say:
Shuffle the elements of an array uniformly at random along an axis.
But the numpy docs for numpy.random.shuffle do make this explicit, saying that
This function only shuffles the array along the first axis of a multi-dimensional array. The order of sub-arrays is changed but their contents remains the same.
So I'm not sure whether the intent is for the JAX behavior to mimic the numpy behavior here, but that is what I was expecting. To demonstrate / reproduce this difference in behavior, see this gist.
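For reference, here is a minimal sketch of the property I was expecting, written against numpy only. The helper name `is_row_permutation` is hypothetical (it is not part of numpy or JAX), and the sketch assumes a 2-D input. It checks whether a result contains exactly the rows of the original in some order, which is what the numpy docs describe for numpy.random.shuffle.

```python
import numpy as np


def is_row_permutation(original, shuffled):
    """Return True if `shuffled` has exactly the rows of `original`, in any order.

    Hypothetical helper for illustration; assumes both inputs are 2-D.
    """
    original = np.asarray(original)
    shuffled = np.asarray(shuffled)
    if original.shape != shuffled.shape:
        return False
    # Sort the rows lexicographically (first column is the primary key)
    # so that the order of the rows no longer matters in the comparison.
    order_o = np.lexsort(original.T[::-1])
    order_s = np.lexsort(shuffled.T[::-1])
    return bool(np.array_equal(original[order_o], shuffled[order_s]))


x = np.arange(12).reshape(4, 3)

# numpy.random.shuffle works in place and only reorders along the first axis.
np.random.seed(0)
rows_shuffled = x.copy()
np.random.shuffle(rows_shuffled)

# A deterministic rearrangement that mixes elements across rows,
# standing in for a shuffle that does not keep rows intact.
scrambled = x.T.reshape(4, 3)

print(is_row_permutation(x, rows_shuffled))  # rows kept intact
print(is_row_permutation(x, scrambled))      # rows broken up
```

The same check could be applied to the output of the JAX call after converting it with np.asarray, to see which of the two behaviors it follows.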