File tree Expand file tree Collapse file tree 2 files changed +36
-1
lines changed Expand file tree Collapse file tree 2 files changed +36
-1
lines changed Original file line number Diff line number Diff line change @@ -1293,8 +1293,16 @@ def explain_tracing_cache_miss(
1293
1293
1294
1294
p (f" for { func_name } { src_info } " )
1295
1295
1296
+ # Do *not* remove the list() around the call to keys(). The cache may be
1297
+ # updated concurrently by other threads, and we need to perform the iteration
1298
+ # over the dictionary keys in a way that is concurrency safe. Here we are
1299
+ # relying on an implementation behavior of CPython wherein the particular list
1300
+ # constructor used here acts atomically.
1301
+ # See https://github.com/jax-ml/jax/issues/30163
1302
+ cache_keys = list (cache .keys ())
1303
+
1296
1304
diffs = [diff_tracing_cache_keys (key , ok , debug_info )
1297
- for ok in cache . keys () if key != ok ]
1305
+ for ok in cache_keys if key != ok ]
1298
1306
assert diffs , "we must find some diffs if key differs from all cache keys"
1299
1307
min_diff = min (diffs , key = lambda v : v [1 ])
1300
1308
smallest_diffs : Sequence [Sequence [str ]] # the diffs for the closest keys
Original file line number Diff line number Diff line change 33
33
import re
34
34
import subprocess
35
35
import sys
36
+ import threading
36
37
import traceback
37
38
import types
38
39
from typing import NamedTuple
@@ -4925,6 +4926,32 @@ def test_cache_miss_explanations_no_source_info(self):
4925
4926
with config .explain_cache_misses (True ):
4926
4927
jax .jit (operator .add )(42 , 24 )
4927
4928
4929
def test_cache_miss_explanations_are_thread_safe(self):
  """Explaining tracing-cache misses from many threads at once must not raise.

  Ten worker threads call the same jitted function on zero arrays of mutually
  distinct sizes with `explain_cache_misses` enabled. Some threads therefore
  insert into the tracing cache while others iterate its keys to build a
  miss explanation.
  Regression test for https://github.com/jax-ml/jax/issues/30163.
  """
  @jax.jit
  def f(i):
    return jnp.sum(i)

  # Set by any worker that raises, so the remaining workers can stop early.
  failed = False

  def worker(start):
    nonlocal failed
    try:
      # Worker `start` uses sizes start, start+10, ..., start+90 (all < 100).
      # The size sets are disjoint across workers, so each worker triggers
      # cache misses of its own.
      for size in range(start, 100, 10):
        if failed:
          break
        with config.explain_cache_misses(True):
          f(jnp.zeros(size))
    except Exception:
      failed = True
      # Re-raise so the traceback is reported by threading.excepthook.
      raise

  workers = [
      threading.Thread(target=worker, args=(start,)) for start in range(10)
  ]
  for w in workers:
    w.start()
  for w in workers:
    w.join()
  self.assertFalse(failed)
4954
+
4928
4955
@parameterized .named_parameters ([
4929
4956
{"testcase_name" : f"{ np .dtype (dtype )} " , "dtype" : dtype }
4930
4957
for dtype in jtu .dtypes .custom_floats ])
You can’t perform that action at this time.
0 commit comments