66 commits
6f50677
hilbert transform
joglekara Mar 21, 2023
2bdf6c9
newline
joglekara Mar 21, 2023
cf5ddb0
dtype errors
joglekara Mar 24, 2023
e120f13
passing tests
joglekara Mar 24, 2023
0fd7920
linters
joglekara Mar 24, 2023
abf8947
Optional int
joglekara Mar 24, 2023
cbc25dc
Raise a better error message when there is a device assignment mismat…
yashk2810 Mar 21, 2023
ab91657
Fix inspect_array_sharding with grad.
nouiz Mar 21, 2023
c092312
Relax type annotations on lax slicing functions.
hawkinsp Mar 21, 2023
8e2e2f5
[jax2tf] Clean up the jax2tf sharding_tests.
gnecula Mar 21, 2023
41c1d93
Remove the config.jax_array and jax_jit_pjit_api_merge flag usage sin…
yashk2810 Mar 21, 2023
8717494
Document ShapeDtypeStruct
jakevdp Mar 21, 2023
fa19118
[jax2tf] Add back_compat test for LuDecomposition
atondwal Mar 21, 2023
ddab64a
LuDecomposition moved from fallback path to custom_call
atondwal Mar 21, 2023
48db6c8
[PJRT C API] Add parsing PJRT client create options from json file.
Mar 21, 2023
ffc8a34
`custom_vjp` symbolic zeros support
froystig Mar 14, 2023
ce3f534
[jax2tf] Fix grad of pjit in native lowering.
gnecula Mar 21, 2023
4aa8ae9
Fix mypy failures in jax2tf.
hawkinsp Mar 22, 2023
f4a40dc
add wald random generator
JiaYaobo Mar 21, 2023
92e79b3
Revert: `custom_vjp` symbolic zeros support
hawkinsp Mar 22, 2023
fcac7b4
jax.random: remove scale from wald function
jakevdp Mar 22, 2023
78488f0
Improve handling of dynamic shapes in jax native serialization
GleasonK Mar 22, 2023
c5ba4d3
make mlir arg and result names work with pmap
mattjj Mar 18, 2023
499372d
DOC: add formulae for distributions in jax.random
jakevdp Mar 22, 2023
aa46778
WAR the dependency issue in the nightly CI container.
nouiz Mar 22, 2023
f9d73cb
Add print statement to help debug spurious test failure
sharadmv Mar 22, 2023
3039951
add experimental jax_log_checkpoint_residuals option
mattjj Mar 22, 2023
4a27af3
Redefine `compile_and_serialize` as `serialize(lowered.compile())`.
pschuh Mar 23, 2023
6fee63c
enable pjit recursive typechecking
mattjj Mar 22, 2023
dd9f178
fix typo: "one of more" -> "one or more"
a-googler Mar 23, 2023
b0e0a94
Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses
Conchylicultor Mar 23, 2023
a9b4310
[jax2tf] Fix tests broken by upgrade of XlaCallModule
gnecula Mar 23, 2023
8f6e3c4
[XLA:Python] Change JAX and the XLA Python extension to get NumPy bfl…
hawkinsp Mar 23, 2023
c68c3d3
jnp.mean: fix incorrect return value for large arrays
jakevdp Mar 21, 2023
75fcc3a
[XLA:Python] Allow passing ExecutableBuildOptions to outfeed receiver.
hawkinsp Mar 23, 2023
0e2cf94
Add missing file
nouiz Mar 23, 2023
8f4b8a0
Fix loc and scale parameters in scipy.logistic. Add CDF and SF for se…
b0nce Mar 12, 2023
5b30f8e
Move jax._src.typing into a separate Bazel target.
hawkinsp Mar 23, 2023
7a326b3
Deprecated xla_call_p since it has been replaced with pjit.pjit_p
yashk2810 Mar 23, 2023
6ee6598
Fix mypy issue in jax/experimental/jet.py
jakevdp Mar 23, 2023
180f12e
add trailing-whitespace pre-commit hook
cgarciae Mar 23, 2023
fdea9e6
[jax2tf] Minor addition to the documentation
gnecula Mar 23, 2023
1e356cf
internal: refactor array methods into separate private submodule
jakevdp Mar 23, 2023
23b0743
Refactor special functions into their own module.
atondwal Mar 23, 2023
071d9b9
improve scan error messages
mattjj Mar 23, 2023
175cd37
[jax2tf] Create a jax_export library with JAX-only pieces for native …
gnecula Mar 23, 2023
4994472
Add padding option "SAME_LOWER" for ticket https://github.com/google/…
maxwillzq Mar 23, 2023
32b8c42
[jax2tf] A simple failing test on TPU with native serialization
gnecula Mar 23, 2023
195f847
Migrate igamma_p off xla_fallback
atondwal Mar 23, 2023
f63a09c
lax_numpy: move quantile-based functions to reductions.py
jakevdp Mar 23, 2023
d138853
Increase minimum NumPy version to 1.21.
hawkinsp Feb 6, 2023
b6b1e42
Remove PJRT C API bypass.
skye Mar 24, 2023
fd36ed6
[dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit
mattjj Mar 23, 2023
10aeadb
fix jax.Array.round()
mattjj Mar 24, 2023
9d60043
[jaxlib] fix build w/ depenency on stablehlo_serialization
atondwal Mar 24, 2023
a88509a
Migrate igammac_p off xla_fallback path
atondwal Mar 24, 2023
c6bc3ed
Migrate igamma_grad_a_p off xla_fallback
atondwal Mar 24, 2023
aa37c74
Migrate random_gamma_grad off xla_fallback
atondwal Mar 24, 2023
2900c78
remove another dependency not currently needed.
nouiz Mar 24, 2023
f1af74f
WAR ssh timeout like:
nouiz Mar 23, 2023
14f1f60
After the SPMD bug fix, always take the _rewriting_take route for get…
yashk2810 Mar 24, 2023
b9bd60d
Guard ArrayImpl checks by xla_extension_version.
pschuh Mar 24, 2023
5c08753
[1/n] store embedded tf.graph to stablehlo.custom_call
maxwillzq Mar 24, 2023
5b43258
Delete remote TPU support.
hawkinsp Mar 24, 2023
f83a347
docs
joglekara Mar 24, 2023
3b11662
Merge branch 'google:main' into hilbert
joglekara Mar 24, 2023
make mlir arg and result names work with pmap
This is a follow-up on #15080 to restore (and indeed fix!) how pmap builds a
jaxpr with debug info (i.e. parameter names and result paths). The difference
with the machinery in #15080 is just to deal with pmap being final-style (i.e.
build the jaxpr at the last second, well after pytrees have been flattened away
and transformations have been applied), whereas the machinery for pjit in
#15080 is initial-style (i.e. build the jaxpr immediately). As you might
imagine, plumbing for the former is a bit more long-range and subtle.

The main idea here is that we need to annotate and maintain debug info on the
lu.WrappedFun instance, which we first form at the api.py level, then pass
through transformations (which can either update or drop debug info), then
finally hand off to the impl rule to be traced to a jaxpr. It makes sense as an
annotation, parallel with the in_type annotation used for dynamic shapes,
because the debug info has to be updated as transformations are applied, since
they might e.g. add tangent inputs and outputs.
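
To make the "update or drop" idea concrete, here is a minimal pure-Python sketch. It does not use JAX: `ToyDebugInfo`, `ToyWrappedFun`, `add_tangent_inputs`, and `opaque_transform` are hypothetical stand-ins invented for illustration, and the "tangent of ..." naming is an assumption, not what JAX emits.

```python
from typing import Callable, NamedTuple, Optional, Tuple

class ToyDebugInfo(NamedTuple):
  # Hypothetical, simplified stand-in for TracingDebugInfo.
  traced_for: str
  arg_names: Tuple[str, ...]

class ToyWrappedFun(NamedTuple):
  # Hypothetical stand-in for lu.WrappedFun: a callable plus an annotation.
  f: Callable
  debug_info: Optional[ToyDebugInfo]

def add_tangent_inputs(w: ToyWrappedFun) -> ToyWrappedFun:
  """Toy forward-mode transform: doubles the inputs and updates the names."""
  def f_jvp(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    # Applying f to the tangents is only correct for the linear toy function below.
    return w.f(*primals), w.f(*tangents)
  dbg = w.debug_info
  if dbg is not None:
    names = dbg.arg_names + tuple(f"tangent of {n}" for n in dbg.arg_names)
    dbg = dbg._replace(arg_names=names)
  return ToyWrappedFun(f_jvp, dbg)

def opaque_transform(w: ToyWrappedFun) -> ToyWrappedFun:
  """Toy transform that cannot keep the names accurate, so it drops them."""
  return ToyWrappedFun(lambda *args: w.f(*args), None)

w = ToyWrappedFun(lambda x, y: 2 * x + 3 * y, ToyDebugInfo("pmap", ("x", "y")))
w_jvp = add_tangent_inputs(w)
w_opaque = opaque_transform(w)
print(w_jvp.debug_info.arg_names)
print(w_opaque.debug_info)
```

Dropping the annotation is the safe default: stale argument names would be worse than none, which is why a transformation that does not know how to update them returns `None`.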

In more detail: with an initial-style higher-order primitive (like pjit), a
jaxpr is formed immediately. Transformations, like autodiff, are
jaxpr-to-jaxpr, and so those transformations (like ad.jvp_jaxpr) need to return
a new jaxpr either with updated debug info or no debug info at all. (The initial
implementation in #15080 doesn't provide updated debug info in any of those
jaxpr-to-jaxpr transformation functions, so the debug info is only applied to
the jaxpr and then lowered to MLIR when the pjit is at the top level.)

For final-style, like pmap here, instead of transformations being
jaxpr-to-jaxpr, they're WrappedFun-to-WrappedFun. And so, analogously,
transformations, like ad.JVPTrace.process_map, would need to produce a
WrappedFun with updated debug info or no debug info at all. (Also analogously
to #15080, this PR only implements enough for the debug info to be preserved
for top-level pmaps.)
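
The initial-style versus final-style distinction can be sketched without JAX. The toy below is an illustration only: `Tracer`, `trace`, `scale_program`, and `scale_fun` are hypothetical names, and the "program" is just a tuple of strings rather than a real jaxpr.

```python
from typing import Callable, List, Tuple

class Tracer:
  """Records multiplications applied to it, like a very small staging tracer."""
  def __init__(self, name: str, log: List[str]):
    self.name, self.log = name, log
  def __mul__(self, k):
    out = Tracer(f"({self.name} * {k})", self.log)
    self.log.append(out.name)
    return out
  __rmul__ = __mul__

def trace(f: Callable) -> Tuple[str, ...]:
  """Run f on a tracer and return the recorded program."""
  log: List[str] = []
  f(Tracer("x", log))
  return tuple(log)

def f(x):
  return 3 * x

# Initial-style: trace immediately, then transform program -> program.
def scale_program(program: Tuple[str, ...], k) -> Tuple[str, ...]:
  return program + (f"({program[-1]} * {k})",)

initial = scale_program(trace(f), 2)

# Final-style: transform callable -> callable, trace only at the very end.
def scale_fun(g: Callable, k) -> Callable:
  return lambda x: g(x) * k

final = trace(scale_fun(f, 2))
print(initial == final)
```

Both routes record the same program, but in the final-style route nothing program-like exists until the last step, so any debug info has to travel with the callable in the meantime.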

This PR doesn't yet delete the trace-time debug info in partial_eval.py. But
that'll happen too!
mattjj authored and joglekara committed Mar 24, 2023
commit c5ba4d3daf3aa2b45383733afbab70c39903ceab
6 changes: 5 additions & 1 deletion jax/_src/api.py
@@ -59,7 +59,7 @@
     rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
     shaped_abstractify, _ensure_str_tuple, argnames_partial_except,
     validate_argnames, validate_argnums, check_callable, resolve_argnums,
-    FLAGS)
+    debug_info, result_paths, debug_info_final, FLAGS)
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
@@ -1634,6 +1634,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
   if in_devices is not None and len(in_devices) == 0:
     raise ValueError("'devices' argument to pmap must be non-empty, or None.")

+  dbg = debug_info('pmap', fun, args, kwargs, static_broadcasted_tuple, ())
+
   f = lu.wrap_init(fun)
   if static_broadcasted_tuple:
     if max(static_broadcasted_tuple) >= len(args):
@@ -1671,7 +1673,9 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
                                       kws=True))
   local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

+  f, res_paths = result_paths(f)
   flat_fun, out_tree = flatten_fun(f, in_tree)
+  flat_fun = debug_info_final(flat_fun, dbg, res_paths)

   if any(out_axis is None for out_axis in tree_flatten(out_axes)):
     raise NotImplementedError("None out_axes in pmap are not supported yet")
44 changes: 24 additions & 20 deletions jax/_src/api_util.py
@@ -15,8 +15,8 @@
 import inspect
 import operator
 from functools import partial
-from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
-                    Sequence, Set, Tuple, Union,)
+from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
+                    Set, Tuple, Union)
 import warnings

 import numpy as np
@@ -30,7 +30,9 @@
     treedef_children, generate_key_paths, keystr)
 from jax._src.tree_util import _replace_nones
 from jax._src import linear_util as lu
-from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
+from jax._src.linear_util import TracingDebugInfo
+from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
+                           Unhashable)
 from jax._src.config import flags, bool_env
 from jax._src import traceback_util
 traceback_util.register_exclusion(__file__)
@@ -585,22 +587,14 @@ def api_hook(fun, tag: str):
   return fun


-class TracingDebugInfo(NamedTuple):
-  # Packages up trace/staging-time debug info about a func and its parameters,
-  # formed just before staging to a jaxpr and read in trace-time error messages.
-  # TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
-  traced_for: str             # e.g. 'jit', 'scan', etc
-  func_src_info: str          # e.g. f'{fun.__name__} at {filename}:{lineno}'
-  arg_names: Tuple[str, ...]  # e.g. ('args[0]', ... )
-
 def debug_info(traced_for: str, fun: Callable, args: Tuple[Any],
                kwargs: Dict[str, Any], static_argnums: Tuple[int, ...],
                static_argnames: Tuple[str, ...]) -> Optional[TracingDebugInfo]:
   """Try to build trace-time debug info for fun when applied to args/kwargs."""
   src = fun_sourceinfo(fun)
   arg_names = _arg_names(fun, args, kwargs, static_argnums, static_argnames)
   if src is None or arg_names is None: return None
-  return TracingDebugInfo(traced_for, src, arg_names)
+  return TracingDebugInfo(traced_for, src, arg_names, None)

 # TODO(mattjj): make this function internal to this module
 def fun_sourceinfo(fun: Callable) -> Optional[str]:
@@ -635,13 +629,23 @@ def result_paths(*args, **kwargs):
   yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]

 def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: Optional[TracingDebugInfo],
-                     result_paths: Optional[Tuple[Optional[str], ...]]
+                     result_paths: Optional[Tuple[Optional[str], ...]] = None,
                      ) -> core.Jaxpr:
   """Add debug info to jaxpr, given trace-time debug info and result paths."""
-  if trace_debug is not None and result_paths is not None:
-    debug_info = core.JaxprDebugInfo(
-        trace_debug.traced_for, trace_debug.func_src_info,
-        trace_debug.arg_names, result_paths)
-  else:
-    debug_info = None
-  return jaxpr.replace(debug_info=debug_info) if debug_info else jaxpr
+  if trace_debug is None:
+    return jaxpr
+  assert (result_paths is not None) ^ (trace_debug.result_paths is not None)
+  if result_paths is None:
+    result_paths = trace_debug.result_paths()  # type: ignore
+  debug_info = core.JaxprDebugInfo(
+      trace_debug.traced_for, trace_debug.func_src_info,
+      trace_debug.arg_names, result_paths)
+  return jaxpr.replace(debug_info=debug_info)
+
+def debug_info_final(f: lu.WrappedFun, dbg: Optional[TracingDebugInfo],
+                     res_paths: Callable[[], Tuple[str, ...]]) -> lu.WrappedFun:
+  "Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
+  if dbg is None: return f
+  assert dbg.result_paths is None
+  res_paths_ = HashableFunction(res_paths, closure=())
+  return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_))
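
The result paths are carried as a zero-argument thunk rather than a tuple because, for a final-style primitive, the output structure is unknown until the function has actually been run during tracing. A minimal pure-Python sketch of that pattern follows; it does not use JAX, `with_result_paths` is a hypothetical helper, and it assumes the function returns a flat dict.

```python
from typing import Callable, List, Tuple

def with_result_paths(f: Callable) -> Tuple[Callable, Callable[[], Tuple[str, ...]]]:
  """Wrap f so the key paths of its dict output are recorded when it runs."""
  store: List[Tuple[str, ...]] = []

  def wrapped(*args, **kwargs):
    out = f(*args, **kwargs)
    # Assumption for this toy: the output is a flat dict with string keys.
    store.append(tuple(f"['{k}']" for k in sorted(out)))
    return out

  def thunk() -> Tuple[str, ...]:
    # The paths only exist once `wrapped` has been called at least once.
    if not store:
      raise RuntimeError("result paths requested before the function ran")
    return store[-1]

  return wrapped, thunk

f_wrapped, res_paths = with_result_paths(lambda x, y: {'a': x, 'b': [y]})
f_wrapped(1, 2)        # stands in for tracing the function
print(res_paths())     # the paths are available only after the call above
```

In the diff above, the thunk is called inside `jaxpr_debug_info`, which runs after `pe.trace_to_jaxpr_final`, so the function has already been executed by the time the paths are read.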
20 changes: 7 additions & 13 deletions jax/_src/interpreters/pxla.py
@@ -1116,6 +1116,7 @@ def stage_parallel_callable(
                                    event=dispatch.JAXPR_TRACE_EVENT):
       jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
           fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
+  jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
   jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

   assert len(out_sharded_avals) == len(pci.out_axes), (
@@ -1264,7 +1265,6 @@ def lower_parallel_callable(
       raise ValueError("Ordered effects not supported in `pmap`.")
     unordered_effects = list(
         effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
-    arg_names, result_names = _debug_names(jaxpr.debug_info)
     lowering_result = mlir.lower_jaxpr_to_module(
         module_name,
         closed_jaxpr,
@@ -1278,7 +1278,8 @@ def lower_parallel_callable(
         replicated_args=replicated_args,
         arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
         result_shardings=_shardings_to_mlir_shardings(parts.out_parts),
-        arg_names=arg_names, result_names=result_names)
+        arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
     module, keepalive, host_callbacks = (
         lowering_result.module, lowering_result.keepalive,
         lowering_result.host_callbacks)
@@ -2564,7 +2565,6 @@ def lower_sharding_computation(
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
   ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
-  arg_names, result_names = _debug_names(jaxpr.debug_info)
   lowering_result = mlir.lower_jaxpr_to_module(
       module_name,
       closed_jaxpr,
@@ -2579,7 +2579,8 @@ def lower_sharding_computation(
       replicated_args=replicated_args,
       arg_shardings=in_op_shardings,
       result_shardings=out_op_shardings,
-      arg_names=arg_names, result_names=result_names)
+      arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+      result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)

   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
@@ -2750,7 +2751,6 @@ def lower_mesh_computation(
         closed_jaxpr.effects))
     ordered_effects = list(effects.ordered_effects.filter_in(
         closed_jaxpr.effects))
-    arg_names, result_names = _debug_names(jaxpr.debug_info)
     lowering_result = mlir.lower_jaxpr_to_module(
         module_name,
         closed_jaxpr,
@@ -2764,8 +2764,8 @@ def lower_mesh_computation(
         replicated_args=replicated_args,
         arg_shardings=in_partitions,
         result_shardings=out_partitions,
-        arg_names=arg_names,
-        result_names=result_names)
+        arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
+        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
     module, keepalive, host_callbacks = (
         lowering_result.module, lowering_result.keepalive,
         lowering_result.host_callbacks)
@@ -2791,12 +2791,6 @@ def lower_mesh_computation(
       device_assignment=list(mesh.devices.flat),
       committed=True)

-def _debug_names(
-    dbg: Optional[core.JaxprDebugInfo]
-) -> Union[Tuple[None, None],
-           Tuple[Sequence[Optional[str]], Sequence[Optional[str]]]]:
-  return (None, None) if dbg is None else (dbg.arg_names, dbg.result_paths)
-
 class MeshComputation(stages.XlaLowering):
   _hlo: Optional[ir.Module]
   _executable: Optional[MeshExecutable]
39 changes: 30 additions & 9 deletions jax/_src/linear_util.py
@@ -64,7 +64,7 @@ def trans1(static_arg, *dynamic_args, **kwargs):
 from __future__ import annotations

 from functools import partial
-from typing import Any, Tuple, Callable
+from typing import Any, Tuple, Callable, Optional, NamedTuple
 import weakref

 from jax.tree_util import tree_map
@@ -124,14 +124,15 @@ class WrappedFun:
     params: extra parameters to pass as keyword arguments to `f`, along with the
       transformed keyword arguments.
   """
-  __slots__ = ("f", "transforms", "stores", "params", "in_type")
+  __slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info")

-  def __init__(self, f, transforms, stores, params, in_type):
+  def __init__(self, f, transforms, stores, params, in_type, debug_info):
     self.f = f
     self.transforms = transforms
     self.stores = stores
     self.params = params
     self.in_type = in_type
+    self.debug_info = debug_info

   @property
   def __name__(self):
@@ -140,7 +141,7 @@ def __name__(self):
   def wrap(self, gen, gen_static_args, out_store) -> WrappedFun:
     """Add another transform and its store."""
     return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
-                      (out_store,) + self.stores, self.params, None)
+                      (out_store,) + self.stores, self.params, None, None)

   def populate_stores(self, stores):
     """Copy the values from the `stores` into `self.stores`."""
@@ -199,11 +200,13 @@ def transform_to_str(x):
     return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'

   def __hash__(self):
-    return hash((self.f, self.transforms, self.params, self.in_type))
+    return hash((self.f, self.transforms, self.params, self.in_type,
+                 self.debug_info))

   def __eq__(self, other):
     return (self.f == other.f and self.transforms == other.transforms and
-            self.params == other.params and self.in_type == other.in_type)
+            self.params == other.params and self.in_type == other.in_type and
+            self.debug_info == other.debug_info)

 @curry
 def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
@@ -231,15 +234,15 @@ def fun_name(f):
 def wrap_init(f, params=None) -> WrappedFun:
   """Wraps function `f` as a `WrappedFun`, suitable for transformation."""
   params = () if params is None else tuple(sorted(params.items()))
-  return WrappedFun(f, (), (), params, None)
+  return WrappedFun(f, (), (), params, None, None)


-def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
+def annotate(f: WrappedFun, in_type: Optional[core.InputType]) -> WrappedFun:
   assert f.in_type is None
   if in_type is None:
     return f
   _check_input_type(in_type)
-  return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)
+  return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info)

 def _check_input_type(in_type: core.InputType) -> None:
   # Check that in_type is syntactically well-formed
@@ -271,6 +274,24 @@ def valid_size(d) -> bool:
   assert all(provided)


+class TracingDebugInfo(NamedTuple):
+  # Packages up trace/staging-time debug info about a func and its parameters,
+  # formed just before staging to a jaxpr and read in trace-time error messages.
+  # TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
+  traced_for: str             # e.g. 'jit', 'scan', etc
+  func_src_info: str          # e.g. f'{fun.__name__} at {filename}:{lineno}'
+  arg_names: Tuple[str, ...]  # e.g. ('args[0]', ... )
+  result_paths: Optional[Callable[[], Tuple[str, ...]]]
+
+def add_debug_info(f: WrappedFun, debug_info: Optional[TracingDebugInfo]
+                   ) -> WrappedFun:
+  """Produce a new WrappedFun with debug_info attached."""
+  assert f.debug_info is None
+  if debug_info is None:
+    return f
+  return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info)
+
+
 def cache(call: Callable):
   """Memoization decorator for functions taking a WrappedFun as first argument.

2 changes: 0 additions & 2 deletions tests/pmap_test.py
@@ -2036,7 +2036,6 @@ def test_remat_of_pmap_policy(self, remat):
     self.assertEqual(jaxpr_text.count(' cos '), 2)

   def test_pmap_lower_arg_info(self):
-    raise SkipTest("arg info not plumbed to pmap yet")  # TODO(mattjj)
     def f(x, y, *args, **kwargs):
       return y['hi'] + args[1] + sum(kwargs.values())

@@ -2052,7 +2051,6 @@ def f(x, y, *args, **kwargs):
     self.assertIn("kwargs['w']", mhlo_str)

   def test_pmap_lower_result_info(self):
-    raise SkipTest("arg info not plumbed to pmap yet")  # TODO(mattjj)
     def f(x, y, z):
       return {'a': x, 'b': [y]}
