Skip to content

Commit 0e352de

Browse files
committed
Fix a concurrency problem in jax_explain_cache_misses.
Fixes #30163
1 parent 36a716e commit 0e352de

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

jax/_src/pjit.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,8 +1293,16 @@ def explain_tracing_cache_miss(
12931293

12941294
p(f" for {func_name}{src_info}")
12951295

1296+
# Do *not* remove the list() around the call to keys(). The cache may be
1297+
# updated concurrently by other threads, and we need to perform the iteration
1298+
# over the dictionary keys in a way that is concurrency safe. Here we are
1299+
# relying on an implementation behavior of CPython wherein the particular list
1300+
# constructor used here acts atomically.
1301+
# See https://github.com/jax-ml/jax/issues/30163
1302+
cache_keys = list(cache.keys())
1303+
12961304
diffs = [diff_tracing_cache_keys(key, ok, debug_info)
1297-
for ok in cache.keys() if key != ok]
1305+
for ok in cache_keys if key != ok]
12981306
assert diffs, "we must find some diffs if key differs from all cache keys"
12991307
min_diff = min(diffs, key=lambda v: v[1])
13001308
smallest_diffs: Sequence[Sequence[str]] # the diffs for the closest keys

tests/api_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import re
3434
import subprocess
3535
import sys
36+
import threading
3637
import traceback
3738
import types
3839
from typing import NamedTuple
@@ -4925,6 +4926,32 @@ def test_cache_miss_explanations_no_source_info(self):
49254926
with config.explain_cache_misses(True):
49264927
jax.jit(operator.add)(42, 24)
49274928

4929+
def test_cache_miss_explanations_are_thread_safe(self):
  """Check that explaining cache misses from several threads does not raise.

  Ten threads share one jitted function. Thread ``k`` calls it on zero
  arrays of length ``k, k + 10, k + 20, ...`` (below 100) with
  ``explain_cache_misses`` enabled, so the threads use disjoint array
  lengths. The test fails if any thread observed an exception.

  See https://github.com/jax-ml/jax/issues/30163
  """
  # Set by whichever worker first hits an exception; read by all workers so
  # the remaining ones stop early instead of finishing their loops.
  failed = False

  # The name `f` is kept on purpose: the cache-miss explanation reports the
  # name of the function being traced.
  @jax.jit
  def f(i):
    return jnp.sum(i)

  def worker(first_size):
    nonlocal failed
    try:
      for size in range(first_size, 100, 10):
        if failed:
          break
        with config.explain_cache_misses(True):
          f(jnp.zeros(size))
    except Exception:
      # Record the failure for the main thread, then let the exception
      # propagate out of the thread.
      failed = True
      raise

  workers = []
  for first_size in range(10):
    workers.append(threading.Thread(target=worker, args=(first_size,)))
  for w in workers:
    w.start()
  for w in workers:
    w.join()
  self.assertFalse(failed)
4954+
49284955
@parameterized.named_parameters([
49294956
{"testcase_name": f"{np.dtype(dtype)}", "dtype": dtype}
49304957
for dtype in jtu.dtypes.custom_floats])

0 commit comments

Comments
 (0)