
Add support for lax.pmean #929

@jheek

Description


When using pmap, I often miss a parallel mean function, lax.pmean(x, axis_name).
I guess it wouldn't be too complicated to write it as lax.psum(x, axis_name) / axis_size. However, there is currently no way to retrieve the size of a parallel axis from within a parallel computation.
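A minimal sketch of the workaround, assuming a JAX version where collectives accept Python constants: summing the constant 1 over a named axis counts the participants, which gives the axis size, and dividing a psum by that count gives the mean. The helpers axis_size, pmean and normalize are hypothetical names for illustration (recent JAX releases ship jax.lax.pmean directly). The example uses jax.vmap with axis_name, which supports the same collectives as pmap, so it runs on a single CPU device.

```python
import jax
import jax.numpy as jnp
from jax import lax


def axis_size(axis_name):
    # Hypothetical helper: psum of the constant 1 over a named axis
    # counts the participants, i.e. the size of that parallel axis.
    return lax.psum(1, axis_name)


def pmean(x, axis_name):
    # Hypothetical helper: all-reduce sum divided by the participant count.
    return lax.psum(x, axis_name) / axis_size(axis_name)


def normalize(x):
    # Each participant subtracts the mean taken across the named axis "i".
    return x - pmean(x, "i")


xs = jnp.arange(4.0)  # [0., 1., 2., 3.], mean 1.5
# vmap with axis_name runs on one device; with pmap instead, the leading
# axis of xs would have to match jax.local_device_count().
out = jax.vmap(normalize, axis_name="i")(xs)
print(out)  # values: -1.5, -0.5, 0.5, 1.5
```

Swapping jax.vmap for jax.pmap keeps the same function body; only the mapping transform and the device requirements change.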

Labels: enhancement (New feature or request), question (Questions for the JAX team)
