Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
from functools import partial
import inspect
import itertools as it
import sys
import threading
import weakref
import types
from typing import (Any, Callable, Generator, Iterable, NamedTuple, Mapping,
Optional, Sequence, Tuple, TypeVar, Union, overload, Dict,
Hashable, List)
Expand Down Expand Up @@ -452,6 +448,11 @@ def f_jitted(*args, **kwargs):

f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
backend, donate_argnums, inline, keep_unused)

def clear_cache():
dispatch.xla_callable.evict_function(fun)
f_jitted.clear_cache = clear_cache

return f_jitted

def _flat_axes_specs(abstracted_axes, *args, **kwargs
Expand All @@ -477,6 +478,11 @@ class _FastpathData(NamedTuple):

_cpp_jit_cache = jax_jit.CompiledFunctionCache()


def _cpp_jit_clear_cache(self):
self._clear_cache()
dispatch.xla_callable.evict_function(self._fun)

def _cpp_jit(
fun: Callable,
*,
Expand Down Expand Up @@ -526,7 +532,7 @@ def cache_miss(*args, **kwargs):
# inspect the argument x, we actually do need to execute it and look at the
# outputs that could be tracers (if f is capturing `Tracer` by closure).
execute: Optional[functools.partial] = (
dispatch._xla_callable.most_recent_entry())
dispatch.xla_callable.most_recent_entry())
# TODO(sharadmv): Enable fast path for effectful jaxprs
# TODO(sharadmv): Clean up usage of `execute.args`
use_fastpath = (
Expand Down Expand Up @@ -592,6 +598,8 @@ def get_device_info():

f_jitted.lower = _jit_lower(fun, static_argnums, static_argnames, device,
backend, donate_argnums, inline, keep_unused)
f_jitted._fun = fun
type(f_jitted).clear_cache = _cpp_jit_clear_cache

return f_jitted

Expand Down
8 changes: 4 additions & 4 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,15 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline, keep_unused: bool):
del inline # Only used at tracing time
arg_specs = unsafe_map(arg_spec, args)
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
keep_unused, *arg_specs)
compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
keep_unused, *arg_specs)
try:
return compiled_fun(*args)
except FloatingPointError:
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
print("Invalid value encountered in the output of a jit-decorated function. "
"Calling the de-optimized version.")
# We want to run the wrapped function again (after _xla_callable already ran
# We want to run the wrapped function again (after xla_callable already ran
# it), but linear_util.WrappedFun instances are meant to be run only once.
# In addition to re-executing the Python code, which is usually undesirable
# but which config.jax_debug_nans is meant to opt into, we'll be
Expand Down Expand Up @@ -245,7 +245,7 @@ def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
keep_unused, *arg_specs).compile().unsafe_call

_xla_callable = lu.cache(_xla_callable_uncached)
xla_callable = lu.cache(_xla_callable_uncached)


@contextlib.contextmanager
Expand Down
4 changes: 4 additions & 0 deletions jax/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,12 @@ def _most_recent_entry():
thread_local.most_recent_entry = None
return result

def _evict_function(f):
fun_caches.pop(f, None)

memoized_fun.most_recent_entry = _most_recent_entry # type: ignore
memoized_fun.cache_clear = fun_caches.clear # type: ignore
memoized_fun.evict_function = _evict_function # type: ignore

return memoized_fun

Expand Down
13 changes: 13 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,19 @@ def f(x):
python_should_be_executing = False
self.jit(f)(3)

def test_jit_cache_clear(self):
@self.jit
def f(x, y): return x + y

client = jax.devices()[0].client
num_live_initial = len(client.live_executables())
f(1, 2).block_until_ready()
num_live = len(client.live_executables())
self.assertEqual(num_live_initial + 1, num_live)
f.clear_cache()
num_live = len(client.live_executables())
self.assertEqual(num_live_initial, num_live)

def test_jit_shallow_copy(self):
def f(x):
return copy.copy(x)
Expand Down