Add support for multihost pmaps. #1376
Conversation
You should ideally get a review from Matt if he's able to.
    WrapHashably, Hashable, prod, split_list)
  from .lib.xla_bridge import (canonicalize_dtype, device_count,
-                              local_device_count, devices, host_id)
+                              local_device_count, devices, host_id, host_count)
Nit, not new: we're now importing 5 different things from xla_bridge; I think I'd prefer we just imported the module and qualified the names.
These are imported to bring them into the `jax` namespace, e.g. so you can call jax.host_count(). I would prefer we dealt with this in a different way, but haven't gotten around to trying anything...
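For readers following the discussion, a small illustrative sketch of the two styles being compared; this is not code from the PR, just the usage the comments describe:

```python
import jax
from jax.lib import xla_bridge

# Re-exported name: works because host_count is imported into the jax namespace.
n_hosts = jax.host_count()

# Module-qualified alternative, along the lines of the nit above.
n_hosts_alt = xla_bridge.host_count()
```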
jax/interpreters/pxla.py
Outdated
_get_global_axis_size_pmapped = None

def _get_global_axis_size(local_axis_size):
  """Uses pmap to sum `local_axis_size` across all hosts.
While this approach is cunning, don't we know the global topology from the low-level runtime stack? Can't we use that here instead?
In particular the module dependency structure seems a bit odd to me with this design.
It also feels strange to me that we have to special case multihost computations here; it feels to me like the same code should work in both single host and multihost cases.
As discussed offline, I got rid of this for now and instead require that multihost pmaps run on all devices or have `devices` specified.
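A hedged illustration of the two allowed modes mentioned above (run across all devices, or pass `devices` explicitly); the names here are made up for the example:

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

# Mode 1: no `devices` argument -- the pmap is assumed to span all hosts' devices.
f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i')

# Mode 2: `devices` specified explicitly; every participating host must pass
# the same global device list (same ordering on each host).
g = pmap(lambda x: lax.psum(x, 'i'), axis_name='i', devices=jax.devices())

# Each host feeds in a shard sized to its local participating devices.
out = g(jnp.ones(jax.local_device_count()))
```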
Force-pushed from 77ce8f6 to 1f36e3d.
LGTM! Clear explanations, simple `global_axis_size` implementation, plus logging and error checking. Thanks for adding this!
Force-pushed from 72512f9 to 9dd43b3.
All participating hosts are assumed to be running the same pmap code. Conceptually, this can be considered a single pmap over an array sharded on its leading pmapped dimension across the hosts. Each host passes its input shard to its pmapped function call, which returns the corresponding output shard (i.e. an array of the same leading dimension size). However, any collective operations will be run across the entire "global" array.

If the `devices` argument to pmap is None, the pmap is assumed to be running across all hosts visible to XLA (as returned by jax.host_count()). Each host can pass in an input array of leading dimension size equal to or less than the number of devices local to that host. Note that this doesn't change the current behavior for single-host platforms. If `devices` are specified, the participating hosts are dictated by the devices' host_ids, and each host must pass in an input array of leading dimension size equal to the number of local participating devices.

Implementation-wise, each host independently compiles the computation, which we assume yields the same executable on all hosts (follow-up work will add more error checking). The hosts must know the global axis size of the sharded array, e.g. to provide the correct replica count to XLA. This is equal to the length of `devices` if specified, but if not, pmap is recursively called (with `devices` specified) to use `psum` to compute the global axis size.
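To make the shard-in/shard-out semantics above concrete, a hedged sketch of the kind of program each host would run (names are illustrative, not from the PR):

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

# Every host runs this same program. This host's input shard has leading
# dimension at most jax.local_device_count(); conceptually, all the hosts'
# shards together form one global array sharded across every host's devices.
local_shard = jnp.arange(1, jax.local_device_count() + 1, dtype=jnp.float32)

def normalize(x):
    # The psum runs over the *global* pmapped axis, i.e. across all devices on
    # all participating hosts, even though each host only sees its own shard.
    return x / lax.psum(x, 'i')

local_out = pmap(normalize, axis_name='i')(local_shard)
# local_out has the same leading dimension as local_shard; the other hosts get
# their own output shards from the same conceptual global computation.
```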
Thanks Matt! FYI I reworked the multi-host pmap example in the pmap docstring based on offline suggestions from @necula01. I'm gonna submit now as-is, but if anyone has further comments, I can iterate in subsequent commits.