Skip to content

Commit bcbbaec

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Include a sequence number in the filenames of dumped jaxprs. Without it, when two functions have the same name, a later dump overwrites the earlier one.
PiperOrigin-RevId: 817249274
1 parent 8ec2ccb commit bcbbaec

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

jax/_src/jaxpr_util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ def eqns_using_var(jaxpr: core.Jaxpr, invar: core.Var) -> Iterator[core.JaxprEqn
276276
# if the previous condition fails, there is no deeper jaxpr to explore =(
277277
yield eqn
278278

279+
280+
# Module-level sequence counter. Each dump draws the next value and embeds it,
# zero-padded, in the output filenames. Functions that share a name therefore
# no longer overwrite each other's dumped jaxprs.
_jaxpr_id_counter = itertools.count()
281+
279282
def maybe_dump_jaxpr_to_file(
280283
fun_name: str, jaxpr: core.Jaxpr
281284
) -> str | None:
@@ -296,16 +299,17 @@ def maybe_dump_jaxpr_to_file(
296299
modes = config.jax_dump_ir_modes.value.split(",")
297300
if "jaxpr" not in modes and "eqn_count_pprof" not in modes:
298301
return None
302+
id = next(_jaxpr_id_counter)
299303
if "jaxpr" in modes:
300304
logging.log(
301305
logging.INFO, "Dumping jaxpr for %s to %s.", fun_name, out_dir
302306
)
303-
jaxpr_path = out_dir / f"{fun_name}.jaxpr.txt"
307+
jaxpr_path = out_dir / f"jax_{id:06d}_{fun_name}.jaxpr.txt"
304308
jaxpr_path.write_text(jaxpr.pretty_print())
305309
if "eqn_count_pprof" in modes:
306310
logging.log(
307311
logging.INFO, "Dumping eqn count pprof for %s to %s.", fun_name, out_dir
308312
)
309-
eqn_prof_path = out_dir / f"{fun_name}.eqn_count_pprof"
313+
eqn_prof_path = out_dir / f"jax_{id:06d}_{fun_name}.eqn_count_pprof"
310314
eqn_prof_path.write_bytes(pprof_equation_profile(jaxpr))
311315
return fun_name

0 commit comments

Comments (0)