Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions docs/debugging/xla_metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

**Summary:** `set_xla_metadata` allows you to attach metadata to operations in your JAX code. This metadata is passed down to the XLA compiler as `frontend_attributes` and can be used to enable compiler-level debugging tools, such as the XLA-TPU debugger.

You can use it in two ways:
You can use it in three ways:

1. Tag an individual operation by wrapping its output value
2. Tag a block of operations using a context manager
3. Tag all operations in a function using a decorator

**Warning:** `set_xla_metadata` is an experimental feature and its API is subject to change.

Expand Down Expand Up @@ -42,7 +43,7 @@ ENTRY main.5 {
ROOT mul.4 = f32[] multiply(sin.2, cos.3), frontend_attributes={breakpoint="true"}
}
```
## Tagging a Block of Code with a Context Manager
## Tagging a Block of Code with a Context Manager or Decorator
If you want to apply the same metadata to a larger section of code, you can use `set_xla_metadata` as a context manager. All JAX operations within the `with` block will have the specified metadata attached.

```python
Expand All @@ -69,6 +70,25 @@ ENTRY main.5 {
}
```

If you want to tag all operations in a function, you can also use `set_xla_metadata` as a decorator:

```python
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata

# Tagging with a decorator
@set_xla_metadata(_xla_log=True)
@jax.jit
def decorator_tagging(x):
y = jnp.sin(x)
z = jnp.cos(y)
return y * z

print(decorator_tagging.lower(1.0).as_text("hlo"))
```
This will result in the same HLO as above.

# Interaction with JAX Transformations
`set_xla_metadata` uses either an `XlaMetadataContextManager` (when used as a context manager or decorator) or a JAX primitive (when wrapping a value), depending on the use case, and is compatible with JAX transformations such as `jit`, `vmap`, and `grad`.
* **`vmap`**: When you `vmap` a function containing `set_xla_metadata`, the metadata will be applied to all of the relevant batched operations.
Expand Down
1 change: 1 addition & 0 deletions jax/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,7 @@ pytype_strict_library(
":typing",
":util",
":xla_bridge",
":xla_metadata_lib",
"//jax/_src/lib",
] + py_deps("numpy"),
)
Expand Down
18 changes: 17 additions & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@
from jax._src import effects as effects_lib
from jax._src import frozen_dict
from jax._src import hashable_array
from jax._src import literals
from jax._src import jaxpr_util
from jax._src import linear_util as lu
from jax._src import literals
from jax._src import path
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src import xla_metadata_lib
from jax._src.interpreters import partial_eval as pe
from jax._src.layout import AutoLayout, Layout
from jax._src.lib import _jax
Expand Down Expand Up @@ -1788,6 +1789,21 @@ def lower_jaxpr_to_fun(
for attrs in const_arg_attrs:
attrs["jax.const"] = ir.BoolAttr.get(True)

xla_metadata = xla_metadata_lib.current_xla_metadata()
if xla_metadata:
ctx_attributes = {
k: ir.StringAttr.get(str(v).lower()) for k, v in xla_metadata.items()
}
for i in range(num_dim_vars + num_tokens, len(flat_input_types)):
attrs = arg_attrs[i]
existing_attributes = {}
if "mhlo.frontend_attributes" in attrs:
for a in attrs["mhlo.frontend_attributes"].attr:
existing_attributes[a.name] = a.attr
attrs["mhlo.frontend_attributes"] = ir.DictAttr.get(
existing_attributes | ctx_attributes
)

func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
# End populate arg_attrs
Expand Down
55 changes: 50 additions & 5 deletions jax/_src/xla_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from functools import partial, wraps
from typing import Any

from jax._src import config
Expand All @@ -27,6 +27,40 @@
config_ext = xla_client._xla.config


class _XlaMetadataWrapper:
  """Callable proxy that runs a function inside an XLA-metadata context.

  `XlaMetadataContextManager.__call__` returns an instance of this class when
  the context manager is used as a decorator on a function `f`.

  * Calling the wrapper enters `ctx` and then calls `f`.
  * Attribute access is forwarded to `f`. Callable attributes (e.g. the
    `.lower()` method of a jitted function) are wrapped so that they also run
    inside `ctx`; non-callable attributes are returned unchanged.

  Args:
    f: The function (or callable object) being decorated.
    ctx: A re-enterable context manager that is entered around every call.
  """

  def __init__(self, f, ctx):
    self._f = f
    self._ctx = ctx
    # Copy `__name__`, `__doc__`, `__dict__`, ... from `f` onto this instance
    # and record `f` as `__wrapped__`.
    wraps(f)(self)

  def __call__(self, *args, **kwargs):
    with self._ctx:
      return self._f(*args, **kwargs)

  def __getattr__(self, name):
    # `__getattr__` only runs after normal lookup has failed. If the missing
    # name is one of our own fields, the instance was created without
    # `__init__` (as `copy` and `pickle` do); forwarding to `self._f` would
    # re-enter this method until RecursionError. Fail with AttributeError so
    # that `hasattr`-style probes behave normally.
    if name in ("_f", "_ctx"):
      raise AttributeError(name)

    attr = getattr(self._f, name)
    if not callable(attr):
      return attr

    @wraps(attr)
    def wrapper(*args, **kwargs):
      with self._ctx:
        return attr(*args, **kwargs)

    return wrapper


class XlaMetadataContextManager:
__slots__ = ["prev", "updates"]

Expand All @@ -47,39 +81,50 @@ def __exit__(self, exc_type, exc_value, traceback):
return
config.xla_metadata_context_manager.set_local(self.prev)

def __call__(self, f):
return _XlaMetadataWrapper(f, self)


def set_xla_metadata(x=None, **kwargs):
  """Attach XLA metadata to a value, a block of code, or a function.

  Args:
    x: Optional pytree of values. When given, every leaf is passed through the
      `xla_metadata_value` primitive with `kwargs` attached.
    **kwargs: Metadata key/value pairs.

  Returns:
    When `x` is None, an `XlaMetadataContextManager` built from `kwargs`; it
    can be used in a `with` statement or applied to a function as a decorator.
    Otherwise, a pytree with the same structure as `x` whose leaves are the
    outputs of the metadata primitive.
  """
  if x is None:
    return XlaMetadataContextManager(kwargs)

  # Primitive parameters must be hashable: freeze the kwargs into a sorted
  # tuple of (key, value) pairs so the result does not depend on keyword order.
  hashable_metadata = tuple(sorted(kwargs.items()))
  return tree_util.tree_map(
      lambda v: xla_metadata_value_p.bind(
          v, xla_metadata_kvs=hashable_metadata
      ),
      x,
  )


# `xla_metadata_value_p` is an identity primitive for attaching frontend_attributes
# to the primitive's producing (parent/owner) op.
xla_metadata_value_p = core.Primitive("xla_metadata_value")
# Register the implementation exactly once.
xla_metadata_value_p.def_impl(
    partial(dispatch.apply_primitive, xla_metadata_value_p)
)
# Abstract evaluation is the identity: the output aval is the input aval.
xla_metadata_value_p.def_abstract_eval(lambda aval, *, xla_metadata_kvs: aval)
batching.defvectorized(xla_metadata_value_p)
# TODO(nbasile): Implement tagging gradient ops with metadata.
# Until then the cotangent is passed through unchanged.
ad.deflinear2(xla_metadata_value_p, lambda ct, _: (ct,))


def _xla_metadata_value_lowering_rule(
    ctx: mlir.LoweringRuleContext, val: ir.Value, *, xla_metadata_kvs
):
  """Lower `xla_metadata_value_p` by tagging the op that produced `val`.

  Args:
    ctx: Lowering rule context (unused by this rule).
    val: The MLIR value flowing through the identity primitive.
    xla_metadata_kvs: Iterable of (key, value) metadata pairs, as produced by
      `set_xla_metadata`.

  Returns:
    A single-element list containing `val` unchanged.
  """
  xla_metadata = dict(xla_metadata_kvs)
  op_to_attach_metadata = _target_op_to_attach_metadata(val)
  # No target op was found for this value; nothing to tag.
  if op_to_attach_metadata is not None:
    _attach_xla_metadata_to_op(xla_metadata, op_to_attach_metadata)
  return [val]


# If we leave `cacheable=True`, when we are in the lowering rule, the `val.owner`
# becomes a cached `FuncOp`. FuncOp.owners are Blocks, which we can't tag.
mlir.register_lowering(
    xla_metadata_value_p, _xla_metadata_value_lowering_rule, cacheable=False
)


def _target_op_to_attach_metadata(value_mlir: ir.Value) -> ir.Operation | None:
Expand Down
60 changes: 57 additions & 3 deletions tests/xla_metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,61 @@ def f(a, b):
f_lowered_text = f.lower(1.0, 2.0).as_text()
self.assertIn('mhlo.frontend_attributes = {a = "10"}', f_lowered_text)

def test_decorator(self):
  """`set_xla_metadata` as a decorator tags the jaxpr, the ops and the args."""
  @set_xla_metadata(a="b")
  @jax.jit
  def f(a, b):
    return a + b

  f_jaxpr = jax.make_jaxpr(f)(1, 2)
  eqns = f_jaxpr.eqns
  # The whole call is traced inside the metadata context, so check every
  # equation. Slicing with `eqns[1:]` would make the loop vacuous whenever the
  # jaxpr holds a single equation.
  # NOTE(review): assumes the traced jaxpr is non-empty and that every
  # equation created under the decorator carries the metadata; confirm.
  self.assertGreater(len(eqns), 0)
  for eq in eqns:
    self.assertDictEqual(eq.ctx.xla_metadata, {"a": "b"})

  f_lowered_text = f.lower(1.0, 2.0).as_text()
  self.assertIn('mhlo.frontend_attributes = {a = "b"}', f_lowered_text)
  self.assertRegex(
      f_lowered_text, r'%arg0:.*mhlo\.frontend_attributes = \{.*a = "b".*\}'
  )
  self.assertRegex(
      f_lowered_text, r'%arg1:.*mhlo\.frontend_attributes = \{.*a = "b".*\}'
  )

def test_decorator_and_context_manager_nested(self):
  """Decorator metadata and inner `with` metadata are both emitted."""
  @set_xla_metadata(a="b")
  @jax.jit
  def f(a, b):
    with set_xla_metadata(c="d"):
      return a + b

  lowered = f.lower(1.0, 2.0)
  module_text = lowered.as_text()

  # Ops inside the nested block carry both the outer and the inner key.
  merged_attrs = 'mhlo.frontend_attributes = {a = "b", c = "d"}'
  self.assertIn(merged_attrs, module_text)

  # Both function arguments carry the decorator's metadata.
  arg_patterns = (
      r'%arg0:.*mhlo\.frontend_attributes = \{.*a = "b".*\}',
      r'%arg1:.*mhlo\.frontend_attributes = \{.*a = "b".*\}',
  )
  for pattern in arg_patterns:
    self.assertRegex(module_text, pattern)

def test_f_nonjitted(self):
  """Lowering a plain function under the context manager tags ops and args."""
  def f_add(a, b):
    return lax.add(a, b)

  arg1 = jnp.arange(2)
  # Lower once inside the metadata context and reuse the text for every
  # assertion below.
  with set_xla_metadata(a="b"):
    f_lowered_text = jax.jit(f_add).lower(arg1, arg1).as_text()

  self.assertIn('mhlo.frontend_attributes = {a = "b"}', f_lowered_text)
  self.assertRegex(
      f_lowered_text, r'%arg0:.*mhlo\.frontend_attributes = \{.*a = "b".*\}'
  )
  self.assertRegex(
      f_lowered_text, r'%arg1:.*mhlo\.frontend_attributes = \{.*a = "b".*\}'
  )

def test_f_attributes_overwrite(self):
Expand Down Expand Up @@ -131,6 +177,14 @@ def f(a, b):
'mhlo.frontend_attributes = {key1 = "val1", key2 = "val2"}',
f_lowered_text,
)
self.assertRegex(
f_lowered_text,
r'%arg0:.*mhlo\.frontend_attributes = \{.*key1 = "val1".*\}',
)
self.assertRegex(
f_lowered_text,
r'%arg1:.*mhlo\.frontend_attributes = \{.*key1 = "val1".*\}',
)

def test_attr_caching_jit(self):
@jax.jit
Expand Down Expand Up @@ -276,7 +330,7 @@ def f(x, y):
f_jaxpr = jax.make_jaxpr(f)(1.0, 2.0)
eqns = f_jaxpr.eqns
for eq in eqns[1:]:
self.assertDictEqual(eq.ctx.attributes, {"a": "b"})
self.assertDictEqual(eq.ctx.xla_metadata, {"a": "b"})

self.assertIn(
'mhlo.frontend_attributes = {a = "b"}', f.lower(1.0, 2.).as_text()
Expand Down Expand Up @@ -308,7 +362,7 @@ def f(dct, x):
f_jaxpr = jax.make_jaxpr(f_vmap)(dct, 1.0)
eqns = f_jaxpr.eqns
for eq in eqns[1:]:
self.assertDictEqual(eq.ctx.attributes, {"a": "d"})
self.assertDictEqual(eq.ctx.xla_metadata, {"a": "d"})
@jax.jit
def f2(x, y):
with set_xla_metadata(a="b"):
Expand Down
Loading