
jax_explain_cache_misses is not thread safe #30163

@SobhanMP

Description

Hi, jax_explain_cache_misses is not thread-safe under free-threading; see the example below. I also have a more fundamental question: is the following snippet supposed to work, or should each thread make its own copy of the jitted f?

```python
import time
import jax
import jax.numpy as jnp
from threading import Thread

jax.config.update("jax_explain_cache_misses", True)


@jax.jit
def f(i):
    return jnp.sum(i)


N = 10
ttd = False


def thread(i0):
    global ttd
    try:
        for i in range(i0, 100000, N):
            if ttd:
                break
            f(jnp.zeros(i))
            time.sleep(1)
    except Exception:
        ttd = True
        raise


t = [Thread(target=thread, args=(i,)) for i in range(N)]
for i in t:
    i.start()

for i in t:
    i.join()
```

Running it produces the following traceback:

```
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/threading.py", line 1043, in _bootstrap_inner
    self.run()
    ~~~~~~~~^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/threading.py", line 994, in run
    self._target(*self._args, **self._kwargs)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/jax_compile_err.py", line 24, in thread
    f(jnp.zeros(i))
    ~^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/traceback_util.py", line 182, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 292, in cache_miss
    executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                                 ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 139, in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
                   ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 686, in _infer_params
    return _infer_params_internal(fun, ji, args, kwargs)
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 710, in _infer_params_internal
    p, args_flat = _infer_params_impl(
                   ~~~~~~~~~~~~~~~~~~^
        fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 606, in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
                                              ~~~~~~~~~~~~~~~~~~^
        flat_fun, in_type, attr_token, IgnoreKey(ji.inline))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/linear_util.py", line 473, in memoized_fun
    explain(fun, cache is new_cache, cache, key, time.time() - start)  # type: ignore
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.pixi/envs/dev/lib/python3.13t/site-packages/jax/_src/pjit.py", line 1370, in explain_tracing_cache_miss
    for ok in cache.keys() if key != ok]
              ~~~~~~~~~~^^
RuntimeError: dictionary changed size during iteration
```

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.6.2
jaxlib: 0.6.2
numpy:  2.3.1
python: 3.13.5 experimental free-threading build | packaged by conda-forge | (main, Jun 16 2025, 08:33:00) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='XXX', release='6.15.4-gentoo-dist', version='#1 SMP PREEMPT_DYNAMIC Sat Jun 28 01:32:27 EDT 2025', machine='x86_64')
```

```
$ nvidia-smi
Fri Jul 11 18:50:57 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.64.03              Driver Version: 575.64.03      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
|  0%   42C    P8             20W /  450W |     541MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            1278      G   /usr/bin/X                              105MiB |
|    0   N/A  N/A            1350      G   /usr/bin/gnome-shell                     17MiB |
|    0   N/A  N/A          100138      C   ....pixi/envs/dev/bin/python3.13        390MiB |
+-----------------------------------------------------------------------------------------+
```

Labels: bug (Something isn't working)