Closed
Labels
bug (Something isn't working)
Description
Hi, jax_explain_cache_misses is not thread-safe under free-threading; see the example below. I also have a more fundamental question: is the following snippet supposed to work, or should each thread make its own copy of the jitted f?
import time
import jax
import jax.numpy as jnp
from threading import Thread

jax.config.update("jax_explain_cache_misses", True)

@jax.jit
def f(i):
    return jnp.sum(i)

N = 10       # number of threads
ttd = False  # set by the first thread that fails so the others stop too

def thread(i0):
    global ttd
    try:
        # Every iteration uses a new input shape, so every call is a cache miss.
        for i in range(i0, 100000, N):
            if ttd:
                break
            f(jnp.zeros(i))
            time.sleep(1)
    except Exception:
        ttd = True
        raise

t = [Thread(target=thread, args=(i,)) for i in range(N)]
for i in t:
    i.start()
for i in t:
    i.join()
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/threading.py", line 1043, in _bootstrap_inner
    self.run()
    ~~~~~~~~^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/threading.py", line 994, in run
    self._target(*self._args, **self._kwargs)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/jax_compile_err.py", line 24, in thread
    f(jnp.zeros(i))
    ~^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 292, in cache_miss
    executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                                 ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 139, in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
                   ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 686, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 710, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ~~~~~~~~~~~~~~~~~~^
        fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 606, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
                                              ~~~~~~~~~~~~~~~~~~^
        flat_fun, in_type, attr_token, IgnoreKey(ji.inline))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/linear_util.py", line 473, in memoized_fun
    explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 1370, in explain_tracing_cache_miss
    for ok in cache.keys() if key != ok]
              ~~~~~~~~~~^^
RuntimeError: dictionary changed size during iteration
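The last frame iterates over cache.keys() while, presumably, another thread inserts a new entry into the same dict. The sketch below is plain Python and does not touch JAX internals: it first reproduces the same RuntimeError deterministically in a single thread, then shows one possible way to avoid it by holding a lock for both inserts and key listings. LockedCache, insert, other_keys and run_threads are hypothetical names made up for this illustration; this is not how JAX implements or fixes its tracing cache.

```python
import threading


def mutate_during_iteration():
    """Single-threaded stand-in for the race: grow a dict while iterating its keys."""
    cache = {0: "a", 1: "b"}
    try:
        for key in cache.keys():
            cache[key + 100] = "new"  # the dict grows in the middle of the iteration
    except RuntimeError as exc:
        return str(exc)
    return None


class LockedCache:
    """Hypothetical cache wrapper: inserts and key listings share one lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self._data = {}

    def insert(self, key, value):
        with self._lock:
            self._data[key] = value

    def other_keys(self, key):
        # Build the list while holding the lock, so the dict cannot change
        # size under the iteration and the caller gets a private copy.
        with self._lock:
            return [ok for ok in self._data if ok != key]


def run_threads(n_threads=8, per_thread=200):
    """Hammer one LockedCache from several threads; return (entry count, errors)."""
    cache = LockedCache()
    errors = []

    def worker(tid):
        try:
            for i in range(per_thread):
                key = (tid, i)
                cache.insert(key, i)
                cache.other_keys(key)
        except Exception as exc:  # collected so the caller can inspect failures
            errors.append(exc)

    threads = [threading.Thread(target=worker, args=(t,)) for t in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return len(cache.other_keys(None)), errors


if __name__ == "__main__":
    print(mutate_during_iteration())
    print(run_threads())
```

The same idea can be applied from user code as a stopgap, assuming all calls to the jitted function go through one shared lock, at the cost of serializing tracing across threads. Whether that is needed, or whether the snippet above is meant to work as written, is exactly the question this issue asks.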
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.6.2
jaxlib: 0.6.2
numpy: 2.3.1
python: 3.13.5 experimental free-threading build | packaged by conda-forge | (main, Jun 16 2025, 08:33:00) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='XXX', release='6.15.4-gentoo-dist', version='#1 SMP PREEMPT_DYNAMIC Sat Jun 28 01:32:27 EDT 2025', machine='x86_64')
$ nvidia-smi
Fri Jul 11 18:50:57 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.64.03              Driver Version: 575.64.03      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
|  0%   42C    P8             20W /  450W |     541MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            1278      G   /usr/bin/X                              105MiB |
|    0   N/A  N/A            1350      G   /usr/bin/gnome-shell                     17MiB |
|    0   N/A  N/A          100138      C   ....pixi/envs/dev/bin/python3.13        390MiB |
+-----------------------------------------------------------------------------------------+