All participating hosts are assumed to be running the same pmap
code. Conceptually, this can be considered a single pmap over an array
sharded on its leading pmapped dimension across the hosts. Each host
passes its input shard to its pmapped function call, which returns the
corresponding output shard (i.e. an array of the same leading
dimension size). However, any collective operations will be run across
the entire "global" array.
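As a rough illustration (a minimal sketch, assuming every participating host runs this same program and using `jax.local_device_count()` to size the local shard):

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

# Each host builds only its local input shard: the leading dimension is
# the number of devices local to this host.
local_shard = jnp.ones((jax.local_device_count(), 128))

# The psum runs over the *global* mapped axis, i.e. across the devices of
# all participating hosts, not just the local ones.
global_sum = pmap(lambda x: lax.psum(x, 'i'), axis_name='i')(local_shard)

# Each host gets back only its corresponding output shard.
assert global_sum.shape == local_shard.shape
```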
If the `devices` argument to pmap is None, the pmap is assumed to be
running across all hosts visible to XLA (as returned by
jax.host_count()). Each host can pass in an input array of leading
dimension size equal to or less than the number of devices local to
that host. Note that this doesn't change the current behavior for
single-host platforms. If `devices` is specified, the participating
hosts are determined by the devices' host_ids, and each host must pass
in an input array whose leading dimension size equals the number of its
local participating devices.
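A hedged sketch of the explicit-`devices` case (the device `host_id` attribute and `jax.host_id()` are the per-host identifiers assumed here; in this example only hosts 0 and 1 participate):

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

# Restrict the pmap to the devices belonging to hosts 0 and 1.
participating = [d for d in jax.devices() if d.host_id in (0, 1)]
n_local = sum(1 for d in participating if d.host_id == jax.host_id())

if n_local > 0:  # hosts with no participating devices skip the call
    # With `devices` given, the local input's leading dimension must equal
    # the number of *local participating* devices on this host.
    x = jnp.zeros((n_local, 16))
    out = pmap(lambda v: lax.pmean(v, 'i'), axis_name='i',
               devices=participating)(x)
```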
Implementation-wise, each host independently compiles the computation,
which we assume yields the same executable on all hosts (follow-up
work will add more error checking). The hosts must know the global
axis size of the sharded array, e.g. to provide the correct replica
count to XLA. This is equal to the length of `devices` if it is
specified; if it is not, pmap is called recursively (with `devices`
specified) and uses `psum` to compute the global axis size.
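The idea behind that fallback is roughly a psum of ones over all devices; a sketch (not the exact internal code):

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

# Each device contributes 1; summing across every participating device
# yields the global axis size. `devices` is pinned to all global devices
# so every host agrees on the replica count.
ones = jnp.ones(jax.local_device_count())
global_axis_size = int(pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
                            devices=jax.devices())(ones)[0])
```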