Commit dc2ee0d

Add support for multihost pmaps.
All participating hosts are assumed to be running the same pmap code. Conceptually, this can be considered a single pmap over an array sharded on its leading pmapped dimension across the hosts. Each host passes its input shard to its pmapped function call, which returns the corresponding output shard (i.e. an array of the same leading dimension size). However, any collective operations will be run across the entire "global" array.

If the `devices` argument to pmap is None, the pmap is assumed to be running across all hosts visible to XLA (as returned by jax.host_count()). Each host can pass in an input array of leading dimension size equal to or less than the number of devices local to that host. Note that this doesn't change the current behavior for single-host platforms. If `devices` are specified, the participating hosts are dictated by the devices' host_ids, and each host must pass in an input array of leading dim size equal to the number of local participating devices.

Implementation-wise, each host independently compiles the computation, which we assume yields the same executable on all hosts (follow-up work will add more error checking). The hosts must know the global axis size of the sharded array, e.g. to provide the correct replica count to XLA. This is equal to the length of `devices` if specified, but if not, pmap is recursively called (with `devices` specified) to use `psum` to compute the global axis size.
1 parent e7b09cf commit dc2ee0d
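The sharding scheme described in the commit message can be sketched in plain Python (no JAX and no real hosts involved; `local_shard` is a hypothetical helper for illustration, not a JAX API):

```python
# Hypothetical helper (not part of JAX): which slice of the conceptual
# "global" array a given host passes to, and receives from, the pmap.
def local_shard(global_array, host_id, local_device_count):
    start = host_id * local_device_count
    return global_array[start:start + local_device_count]

global_input = list(range(8))           # global array sharded across 2 hosts
print(local_shard(global_input, 0, 4))  # host 0's shard: [0, 1, 2, 3]
print(local_shard(global_input, 1, 4))  # host 1's shard: [4, 5, 6, 7]
```

Each host only ever materializes its own shard; the "global" array exists only conceptually, as the concatenation of the shards along the leading dimension.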

File tree

3 files changed: +83 −27 lines changed


jax/api.py

Lines changed: 44 additions & 13 deletions
@@ -51,7 +51,7 @@
 from .util import (unzip2, unzip3, curry, partial, safe_map, safe_zip,
                    WrapHashably, Hashable, prod, split_list)
 from .lib.xla_bridge import (canonicalize_dtype, device_count,
-                             local_device_count, devices, host_id)
+                             local_device_count, devices, host_id, host_count)
 from .abstract_arrays import ShapedArray, raise_to_shaped
 from .interpreters import partial_eval as pe
 from .interpreters import xla
@@ -643,23 +643,36 @@ def pmap(fun, axis_name=None, devices=None, backend=None):
   pure maps, ``pmap`` enables the use of parallel SPMD collective operations,
   like all-reduce sum.
-  The mapped axis size must be less than or equal to the number of XLA devices
-  available (unless ``devices`` is specified, see below). For nested ``pmap``
-  calls, the product of the mapped axis sizes must be less than or equal to the
-  number of XLA devices.
+  The mapped axis size must be less than or equal to the number of local XLA
+  devices available, as returned by ``jax.local_device_count()`` (unless
+  ``devices`` is specified, see below). For nested ``pmap`` calls, the product
+  of the mapped axis sizes must be less than or equal to the number of XLA
+  devices. TODO(skye): support < # local devices on multi-host platforms
+
+  **Multi-host platforms:** On multi-host platforms such as TPU pods, ``pmap``
+  is designed to be used in SPMD Python programs, where every host is running
+  the same Python code such that all hosts run the same pmapped function in the
+  same order. Each host should still call the pmapped function with mapped axis
+  size equal to the number of *local* devices (unless ``devices`` is specified,
+  see below), and an array of the same leading axis size will be returned as
+  usual. However, any collective operations in ``fun`` will be computed over
+  *all* participating devices, including those on other hosts, via
+  device-to-device communication. Conceptually, this can be thought of as
+  running a pmap over a single array sharded across hosts, where each host
+  "sees" only its local shard of the input and output.
 
   Args:
     fun: Function to be mapped over argument axes.
     axis_name: Optional, a hashable Python object used to identify the mapped
       axis so that parallel collectives can be applied.
     devices: This is an experimental feature and the API is likely to change.
       Optional, a sequence of Devices to map over. (Available devices can be
-      retrieved via jax.devices()). If specified, the length of the sequence
-      must be equal to the size of the mapped axis. Nested ``pmap``s with
-      ``devices`` specified in either the inner or outer ``pmap`` are not yet
-      supported.
+      retrieved via jax.devices()). If specified, the size of the mapped axis
+      must be equal to the number of local devices in the sequence. Nested
+      ``pmap`` s with ``devices`` specified in either the inner or outer
+      ``pmap`` are not yet supported.
     backend: This is an experimental feature and the API is likely to change.
-      Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
+      Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.
 
   Returns:
     A parallelized version of ``fun`` with arguments that correspond to those of
@@ -721,10 +734,28 @@ def pmap(fun, axis_name=None, devices=None, backend=None):
   >>> print(doubly_normed.sum((0, 1)))
   1.0
 
+  On multi-host platforms, collective operations operate over all devices,
+  including those on other hosts. For example, assuming the following code runs
+  on two hosts with 4 XLA devices each:
+
+  >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
+  >>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4, 8)
+  >>> out = pmap(f, axis_name='i')(data)
+  >>> print(out)
+  [28 29 30 31] # on host 0
+  [32 33 34 35] # on host 1
+
+  Each host passes in a different length-4 array, corresponding to its 4 local
+  devices, and the psum operates over all 8 values. Conceptually, the two
+  length-4 arrays can be thought of as a sharded length-8 array (in this
+  example equivalent to np.arange(8)) that is mapped over, with the length-8
+  mapped axis given name 'i'. The pmap call on each host then returns the
+  corresponding length-4 output shard.
+
   The ``devices`` argument can be used to specify exactly which devices are used
-  to run the parallel computation. For example, the following code defines
-  two parallel computations, one which runs on the first six devices and one on
-  the remaining two:
+  to run the parallel computation. For example, again assuming a single host
+  with 8 devices, the following code defines two parallel computations, one
+  which runs on the first six devices and one on the remaining two:
 
   >>> from functools import partial
   >>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
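The two-host psum example in the docstring above can be checked with a small pure-Python simulation (no JAX and no real devices required; `simulated_multihost_pmap` is a made-up illustrative name, not a JAX API):

```python
# Pure-Python stand-in for the two-host example: collectives see the
# concatenation of every host's shard, i.e. the conceptual global array.
def simulated_multihost_pmap(fun, shards):
    global_array = [x for shard in shards for x in shard]
    psum = sum(global_array)            # the collective spans all hosts
    return [[fun(x, psum) for x in shard] for shard in shards]

f = lambda x, psum_x: x + psum_x        # mirrors x + jax.lax.psum(x, 'i')
out0, out1 = simulated_multihost_pmap(f, [list(range(4)), list(range(4, 8))])
print(out0)  # [28, 29, 30, 31]  (host 0)
print(out1)  # [32, 33, 34, 35]  (host 1)
```

The sum over the full length-8 array is 28, so each element gains 28, reproducing the per-host outputs shown in the docstring example.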

jax/interpreters/pxla.py

Lines changed: 35 additions & 14 deletions
@@ -22,6 +22,7 @@
 import operator as op
 import threading
 
+from absl import logging
 import numpy as onp
 import six
 from six.moves import reduce
@@ -172,32 +173,37 @@ def replica_groups(nrep, mesh_spec, mesh_axes):
 
 ### the main pmap machinery lowers SPMD jaxprs to multi-replica XLA computations
 
-def compile_replicated(jaxpr, backend, axis_name, axis_size, devices, consts,
-                       tuple_args, *abstract_args):
+def compile_replicated(jaxpr, backend, axis_name, axis_size, global_axis_size,
+                       devices, consts, tuple_args, *abstract_args):
+  jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
+  num_local_replicas = axis_size * jaxpr_replicas
+  num_replicas = global_axis_size * jaxpr_replicas
+  logging.vlog(
+      1, "compile_replicated: axis_size=%d global_axis_size=%d jaxpr_replicas=%d"
+      % (axis_size, global_axis_size, jaxpr_replicas))
+
   if devices is None:
-    num_replicas = axis_size * xla.jaxpr_replicas(jaxpr)
     if num_replicas > xb.device_count(backend):
       msg = ("compiling computation that requires {} replicas, but only {} XLA "
              "devices are available")
       raise ValueError(msg.format(num_replicas, xb.device_count(backend)))
     device_assignment = None
   else:
-    assert all(d.host_id == xb.host_id() for d in devices)
-    if axis_size != len(devices):
+    assert any(d.host_id == xb.host_id() for d in devices)
+    if num_replicas != len(devices):
       raise ValueError("compiling computation that requires %s replicas, "
                        "but %s devices were specified"
-                       % (axis_size, len(devices)))
-    num_replicas = len(devices)
+                       % (num_replicas, len(devices)))
     device_assignment = tuple(d.id for d in devices)
 
-  axis_env = xla.AxisEnv(num_replicas, [axis_name], [axis_size], devices)
+  axis_env = xla.AxisEnv(num_replicas, [axis_name], [global_axis_size], devices)
   arg_shapes = list(map(aval_to_xla_shape, abstract_args))
   built_c = xla.jaxpr_computation(jaxpr, backend, axis_env, consts, (), arg_shapes,
                                   tuple_args=tuple_args)
   compiled = built_c.Compile(
       compile_options=xb.get_compile_options(num_replicas, device_assignment),
       backend=xb.get_backend(backend))
-  return compiled, num_replicas
+  return compiled, num_local_replicas
@@ -446,9 +452,24 @@ def parallel_callable(fun, backend, axis_name, axis_size, devices, *avals):
   pvals = [PartialVal((aval, core.unit)) for aval in avals]
   pval = PartialVal([core.abstract_unit, core.unit])  # dummy value
 
+  if devices:
+    global_axis_size = len(devices)
+  elif xb.host_count() > 1:
+    # TODO(skye): relax this constraint or provide functionality for
+    # automatically passing appropriate `devices`.
+    if axis_size != xb.local_device_count():
+      raise ValueError(
+          "On multi-host platforms, the input to pmapped functions must have "
+          "leading axis size equal to the number of local devices if no "
+          "`devices` argument is specified. Got axis_size=%d, "
+          "num_local_devices=%d" % (axis_size, xb.local_device_count()))
+    global_axis_size = xb.device_count()
+  else:
+    global_axis_size = axis_size
+
   @lu.wrap_init
   def dynamic_fun(dummy, *args):
-    with extend_dynamic_axis_env(axis_name, dummy.trace, axis_size):
+    with extend_dynamic_axis_env(axis_name, dummy.trace, global_axis_size):
       return fun.call_wrapped(*args)
 
   with core.new_master(JaxprTrace, True) as master:
@@ -470,7 +491,8 @@ def dynamic_fun(dummy, *args):
   # Condense many arguments into single tuple argument to avoid a TPU issue.
   tuple_args = len(avals) > 100
   compiled, nrep = compile_replicated(jaxpr, backend, axis_name, axis_size,
-                                      devices, consts, tuple_args, *avals)
+                                      global_axis_size, devices, consts,
+                                      tuple_args, *avals)
   device_ordinals = compiled.DeviceOrdinals()
   assignments = assign_shards_to_replicas(nrep, axis_size)
   handle_args = partial(shard_args, backend, device_ordinals, assignments,
@@ -520,7 +542,7 @@ def _pmap_translation_rule(c, jaxpr, axis_env, const_nodes, freevar_nodes,
   if axis_env.devices is not None or (axis_env.names and devices is not None):
     raise ValueError("Nested pmaps with explicit devices argument.")
   new_env = xla.extend_axis_env(axis_env, axis_name, axis_size)
-  in_nodes_sharded = list(map(partial(_xla_shard, c, new_env.sizes), in_nodes))
+  in_nodes_sharded = list(map(partial(_xla_shard, c), in_nodes))
   sharded_outs = xla.jaxpr_subcomp(c, jaxpr, backend, new_env, const_nodes,
                                    freevar_nodes, *in_nodes_sharded)
   outs = [_xla_unshard(c, xla.axis_groups(new_env, axis_name), r)
@@ -531,14 +553,13 @@
 ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
 pe.map_primitives.add(xla_pmap_p)
 
-def _xla_shard(c, sizes, x):
+def _xla_shard(c, x):
   xla_shape = c.GetShape(x)
   if xla_shape.is_tuple():
     assert not xla_shape.tuple_shapes()
     return x
   else:
     dims = list(xla_shape.dimensions())
-    assert dims[0] == sizes[-1]
     start_indices = _xla_shard_start_indices(c, dims[0], len(dims))
     return c.Reshape(c.DynamicSlice(x, start_indices, [1] + dims[1:]),
                      None, dims[1:])
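The global-axis-size selection added to `parallel_callable` (and the replica counts derived from it in `compile_replicated`) can be sketched as a standalone function; here the host and device counts are passed in explicitly rather than queried from `xla_bridge`, as an illustrative simplification:

```python
# Standalone sketch of the new global_axis_size logic (counts passed in
# explicitly instead of being read from xla_bridge).
def choose_global_axis_size(axis_size, devices, host_count,
                            local_device_count, device_count):
    if devices:
        return len(devices)            # `devices` fixes the global size
    elif host_count > 1:
        # Without `devices`, each host must map over all its local devices.
        if axis_size != local_device_count:
            raise ValueError("axis_size=%d, num_local_devices=%d"
                             % (axis_size, local_device_count))
        return device_count            # every device visible to XLA
    else:
        return axis_size               # single host: local == global

# Two hosts with 4 devices each, no `devices` argument:
g = choose_global_axis_size(4, None, 2, 4, 8)
# compile_replicated then derives both replica counts from the jaxpr factor:
jaxpr_replicas = 1                     # extra replicas required by the jaxpr
num_local_replicas = 4 * jaxpr_replicas
num_replicas = g * jaxpr_replicas
print(num_local_replicas, num_replicas)  # 4 8
```

This mirrors why `compile_replicated` now returns `num_local_replicas` rather than `num_replicas`: the caller shards arguments only across the replicas driven by this host.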

jax/lib/xla_bridge.py

Lines changed: 4 additions & 0 deletions
@@ -197,6 +197,10 @@ def host_id(backend=None):
   return get_backend(backend).host_id()
 
 
+def host_count():
+  return len(set(d.host_id for d in devices()))
+
+
 ### utility functions
 
 @util.memoize
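The new `host_count()` simply counts distinct `host_id`s among the visible devices. A self-contained sketch of the same logic using mock device objects (the `Device` namedtuple here is a stand-in, not JAX's device class):

```python
# Stand-in sketch of host_count(): count distinct host_ids among the devices.
from collections import namedtuple

Device = namedtuple("Device", ["id", "host_id"])
mock_devices = [Device(i, i // 4) for i in range(8)]  # 2 hosts x 4 devices

def host_count(devices):
    return len(set(d.host_id for d in devices))

print(host_count(mock_devices))  # 2
```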
