Conversation

@skye skye (Member) commented Sep 19, 2019

All participating hosts are assumed to be running the same pmap code. Conceptually, this can be considered a single pmap over an array sharded on its leading pmapped dimension across the hosts. Each host passes its input shard to its pmapped function call, which returns the corresponding output shard (i.e. an array of the same leading dimension size). However, any collective operations will be run across the entire "global" array.

If the devices argument to pmap is None, the pmap is assumed to be running across all hosts visible to XLA (as returned by jax.host_count()). Each host can pass in an input array of leading dimension size equal to or less than the number of devices local to that host. Note that this doesn't change the current behavior for single-host platforms. If devices are specified, the participating hosts are dictated by the devices' host_ids, and each host must pass in an input array of leading dim size equal to the number of local participating devices.
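
To make the usage model above concrete, here is a rough sketch (not taken from the PR itself) of a program that every participating host would run; the array shapes and the psum body are placeholders, and it assumes the 2019-era jax.local_device_count() and axis-name APIs.

import jax
import jax.numpy as jnp
from jax import lax, pmap

# Every host runs this same script. With devices=None the pmap spans all hosts
# visible to XLA; this host contributes only its local shard.
n_local = jax.local_device_count()

# Leading dimension <= number of local devices (here: exactly n_local).
x = jnp.ones((n_local, 128))

# Collectives such as psum reduce over the *global* pmapped axis, i.e. over the
# corresponding shards on every participating host, not just the local devices.
global_sum = pmap(lambda v: lax.psum(v, 'i'), axis_name='i')(x)

# The output shard has the same leading dimension size as this host's input.
assert global_sum.shape == x.shape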

Implementation-wise, each host independently compiles the computation, which we assume yields the same executable on all hosts (follow-up work will add more error checking). The hosts must know the global axis size of the sharded array, e.g. to provide the correct replica count to XLA. This is equal to the length of devices if specified, but if not, pmap is recursively called (with devices specified) to use psum to compute the global axis size.
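
As a rough illustration of the psum trick described in the previous paragraph, the sketch below shows one way the global axis size could be computed. The helper name global_axis_size_via_psum and its details are hypothetical, not the PR's exact _get_global_axis_size implementation.

import numpy as onp
import jax
from jax import lax, pmap

def global_axis_size_via_psum(local_axis_size):
  # Each local device contributes local_axis_size / n_local, so a psum over
  # every device on every host yields the sum of local_axis_size over hosts.
  n_local = jax.local_device_count()
  contrib = onp.full(n_local, local_axis_size / n_local, dtype=onp.float32)
  # Listing all devices explicitly sidesteps the chicken-and-egg problem of
  # needing the global axis size in order to set up the pmap itself.
  summed = pmap(lambda v: lax.psum(v, 'i'), axis_name='i',
                devices=jax.devices())(contrib)
  return int(round(float(summed[0])))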

@hawkinsp hawkinsp (Collaborator) left a comment

You should ideally get a review from Matt if he can.

     WrapHashably, Hashable, prod, split_list)
 from .lib.xla_bridge import (canonicalize_dtype, device_count,
-                             local_device_count, devices, host_id)
+                             local_device_count, devices, host_id, host_count)
Collaborator:

Nit, not new: now that we are importing 5 different things from xla_bridge, I think I'd prefer we just imported the module and qualified the names.

@skye (Member Author):

These are imported to bring them into the jax namespace, e.g. so you can call jax.host_count(). I would prefer we dealt with this in a different way, but haven't gotten around to trying anything...
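
(For illustration, not part of the PR: the effect of these re-exports is that users can call the helpers through the top-level package rather than reaching into jax.lib.xla_bridge.)

import jax

num_hosts = jax.host_count()              # exposed in the jax namespace via the api.py imports above
num_local_devices = jax.local_device_count()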

_get_global_axis_size_pmapped = None

def _get_global_axis_size(local_axis_size):
  """Uses pmap to sum `local_axis_size` across all hosts.
Collaborator:

While this approach is cunning, don't we know the global topology from the low-level runtime stack? Can't we use that here instead?

In particular the module dependency structure seems a bit odd to me with this design.

It also feels strange to me that we have to special case multihost computations here; it feels to me like the same code should work in both single host and multihost cases.

@skye (Member Author):

As discussed offline, I got rid of this for now and instead require that multihost pmaps run on all devices or have devices specified.

@skye skye force-pushed the multihost branch 2 times, most recently from 77ce8f6 to 1f36e3d on September 26, 2019 01:42
@mattjj mattjj (Collaborator) left a comment

LGTM! Clear explanations, simple global_axis_size implementation, plus logging and error checking. Thanks for adding this!

@skye skye force-pushed the multihost branch 2 times, most recently from 72512f9 to 9dd43b3 on September 26, 2019 19:38

@skye skye (Member Author) commented Sep 26, 2019

Thanks Matt!

FYI I reworked the multi-host pmap example in the pmap docstring based on offline suggestions from @necula01. I'm gonna submit now as-is, but if anyone has further comments, I can iterate in subsequent commits.

@skye skye merged commit dc2ee0d into jax-ml:master Sep 26, 2019