[Bug?] jax.random.shuffle does not behave like np.random.shuffle in terms of shuffling along a given axis #2066

@zmjjmz

When I call jax.random.shuffle(key, array, axis=0) on an ndarray of shape (n, m), I expect (based on the docs) to get a new array in which each row keeps its internal order and only the order of the rows is shuffled. I'm not sure the docs actually guarantee this; they say:

Shuffle the elements of an array uniformly at random along an axis.

The NumPy docs for numpy.random.shuffle, by contrast, state this explicitly:

This function only shuffles the array along the first axis of a multi-dimensional array. The order of sub-arrays is changed but their contents remains the same.

So I'm not sure whether JAX is intended to mimic the NumPy behavior here, but that is what I was expecting. To demonstrate / reproduce this difference in behavior, see this gist.
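
For readers without access to the gist, here is a minimal NumPy-only sketch of the two behaviors being contrasted. It is not the gist's contents. `rows_preserved` is a hypothetical helper introduced for illustration, and using `Generator.permuted` as a stand-in for "each column shuffled independently" is an assumption about what the JAX call appears to do, not something confirmed here.

```python
import numpy as np


def rows_preserved(original, shuffled):
    """Return True if `shuffled` holds exactly the rows of `original`, in any order.

    Hypothetical helper, for illustration only.
    """
    return sorted(map(tuple, original.tolist())) == sorted(map(tuple, shuffled.tolist()))


rng = np.random.default_rng(0)
x = np.arange(15).reshape(5, 3)  # 5 rows, 3 columns

# Expected (NumPy-style) behavior: whole rows move together along axis 0.
row_shuffled = x.copy()
rng.shuffle(row_shuffled)  # shuffles in place along the first axis

# Contrasting behavior (assumed): every column is permuted independently,
# so the contents of individual rows are generally not kept together.
col_shuffled = rng.permuted(x, axis=0)

print(rows_preserved(x, row_shuffled))  # always True for a row-wise shuffle
print(rows_preserved(x, col_shuffled))  # usually False; depends on the draw
```

Running the same check against the output of the JAX call should show which of the two behaviors it follows.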
