107 changes: 107 additions & 0 deletions docs/async_dispatch.rst
@@ -0,0 +1,107 @@
Asynchronous dispatch
=====================

JAX uses asynchronous dispatch to hide Python overheads. Consider the following
program:

>>> import numpy as onp
>>> from jax import numpy as np
>>> from jax import random
>>> x = random.uniform(random.PRNGKey(0), (1000, 1000))
>>> np.dot(x, x) + 3.
DeviceArray([[258.01971436, 249.64862061, 257.13372803, ...,
236.67948914, 250.68939209, 241.36853027],
[265.65979004, 256.28912354, 262.18252563, ...,
242.03181458, 256.16757202, 252.44122314],
[262.38916016, 255.72747803, 261.23059082, ...,
240.83563232, 255.41094971, 249.62471008],
...,
[259.15814209, 253.09197998, 257.72174072, ...,
242.23876953, 250.72680664, 247.16642761],
[271.22662354, 261.91204834, 265.33398438, ...,
248.26651001, 262.05389404, 261.33700562],
[257.16134644, 254.7543335, 259.08300781, ..., 241.59848022,
248.62597656, 243.22348022]], dtype=float32)

When an operation such as :code:`np.dot(x, x)` is executed, JAX does not wait
for the operation to complete before returning control to the Python program.
Instead, JAX returns a :class:`DeviceArray` value, which is a future, i.e., a
value that will be produced in the future on an accelerator device but isn't
necessarily available immediately. We can inspect the shape or type of a
:class:`DeviceArray` without waiting for the computation that produced it to
complete, and we can even pass it to another JAX computation, as we do with the
addition operation here. Only when we actually inspect the value of the array from
the host, for example by printing it or by converting it into a plain old
:class:`numpy.ndarray`, will JAX force the Python code to wait for the
computation to complete.
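
We can see this by inspecting only the metadata of a result (a minimal sketch,
reusing the :code:`x` defined above): querying the shape returns immediately,
while converting to a host-side array forces a wait.

>>> y = np.dot(x, x)          # returns immediately; y is a future
>>> y.shape                   # metadata is available without blocking
(1000, 1000)
>>> onp.asarray(y).shape      # forces Python to wait for the result
(1000, 1000)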

Asynchronous dispatch is useful since it allows Python code to "run ahead" of
an accelerator device, keeping Python code out of the critical path.
Provided the Python code enqueues work on the device faster than it can be
executed, and provided that the Python code does not actually need to inspect
the output of a computation on the host, a Python program can enqueue
arbitrary amounts of work and avoid having the accelerator wait.
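
For example, a loop like the following only enqueues work: each :code:`np.dot`
call returns a future immediately, and Python waits only at the final
:func:`block_until_ready` calls (a minimal sketch, reusing :code:`x` from above).

>>> results = []
>>> for i in range(10):               # each call returns a future immediately
...     results.append(np.dot(x, x))  # work is enqueued; Python runs ahead
...
>>> _ = [r.block_until_ready() for r in results]  # wait for all enqueued work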

Asynchronous dispatch has a slightly surprising consequence for microbenchmarks.

>>> %time np.dot(x, x)
CPU times: user 267 µs, sys: 93 µs, total: 360 µs
Wall time: 269 µs
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)

269 µs is a surprisingly small time for a 1000x1000 matrix multiplication on CPU!
However, it turns out that asynchronous dispatch is misleading us: we are not
timing the execution of the matrix multiplication, only the time to dispatch
the work. To measure the true cost of the operation we must either read the
value on the host (e.g., convert it to a plain old host-side numpy array), or
use the :func:`block_until_ready` method on a :class:`DeviceArray` value to wait
for the computation that produced it to complete.

>>> %time onp.asarray(np.dot(x, x))
CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms
Wall time: 8.09 ms
array([[255.01973, 246.64862, 254.13371, ..., 233.67949, 247.68939,
238.36853],
[262.6598 , 253.28911, 259.18253, ..., 239.03181, 253.16757,
249.44122],
[259.38916, 252.72748, 258.2306 , ..., 237.83563, 252.41095,
246.62471],
...,
[256.15814, 250.09198, 254.72173, ..., 239.23877, 247.7268 ,
244.16643],
[268.22662, 258.91205, 262.33398, ..., 245.26651, 259.0539 ,
258.337 ],
[254.16135, 251.75433, 256.083 , ..., 238.59848, 245.62598,
240.22348]], dtype=float32)
>>> %time np.dot(x, x).block_until_ready()
CPU times: user 50.3 ms, sys: 928 µs, total: 51.2 ms
Wall time: 4.92 ms
DeviceArray([[255.01972961, 246.64862061, 254.13371277, ...,
233.67948914, 247.68939209, 238.36853027],
[262.65979004, 253.28910828, 259.18252563, ...,
239.03181458, 253.16757202, 249.44122314],
[259.38916016, 252.72747803, 258.23059082, ...,
237.83563232, 252.41094971, 246.62471008],
...,
[256.15814209, 250.09197998, 254.72172546, ...,
239.23876953, 247.72680664, 244.16642761],
[268.22662354, 258.91204834, 262.33398438, ...,
245.26651001, 259.05389404, 258.33700562],
[254.16134644, 251.7543335, 256.08300781, ..., 238.59848022,
245.62597656, 240.22348022]], dtype=float32)

Blocking without transferring the result back to Python is usually faster, and
is often the best choice when writing microbenchmarks of computation times.
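
For instance, a microbenchmark along these lines (a sketch using Python's
standard :mod:`timeit`; not part of this change) measures the full execution
time while avoiding the device-to-host transfer:

>>> import timeit
>>> n = 100
>>> total = timeit.timeit(lambda: np.dot(x, x).block_until_ready(), number=n)

Dividing :code:`total` by :code:`n` then estimates the true per-call cost,
including device execution but excluding the host transfer.
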
10 changes: 8 additions & 2 deletions docs/index.rst
@@ -2,14 +2,20 @@ JAX reference documentation
===============================

Composable transformations of Python+NumPy programs: differentiate, vectorize,
JIT to GPU/TPU, and more
JIT to GPU/TPU, and more.

For an introduction to JAX, start at the
`JAX GitHub page <https://github.com/google/jax>`_.

.. toctree::
:maxdepth: 1
:caption: Notes

async_dispatch

.. toctree::
:maxdepth: 3
:caption: Contents:
:caption: Contents

jax
