docs/debugging/print_breakpoint.md: 4 additions, 0 deletions

@@ -177,6 +177,10 @@ Why? Under the hood, the compiler gets a functional representation of the staged

To preserve the original order of `jax.debug.print`s as written in your Python function, you can use `jax.debug.print(..., ordered=True)`, which will ensure the relative order of prints is preserved. But using `ordered=True` will raise an error under `jax.pmap` and other JAX transformations involving parallelism, since ordering can't be guaranteed under parallel execution.
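A minimal sketch of `ordered=True` inside a jitted function. The function `f` is a hypothetical example, not from the original doc. Both prints share the same ordering effect, so the `x` line should appear before the `y` line. Running this under `jax.pmap` instead of `jax.jit` would raise an error, as described above.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # ordered=True asks JAX to keep these prints in program order.
    jax.debug.print("x: {}", x, ordered=True)
    y = jnp.sin(x)
    jax.debug.print("y: {}", y, ordered=True)
    return y

out = f(1.0)
```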

#### Computation perturbation

Adding `jax.debug.print` or `jax.debug.breakpoint` statements changes the computation that XLA is asked to compile. This can result in numerical discrepancies compared to the same code without debug statements, because XLA may fuse operations differently during compilation. Keep this in mind when debugging numerical issues: the act of adding debug statements can affect the very behavior you're trying to investigate.
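One way to check whether a debug statement perturbs a result is to compile the same function with and without it and compare the outputs. The sketch below assumes a hypothetical `loss` function with a static `debug` flag that toggles the print. Whether any difference actually shows up depends on the backend and the XLA version, so the two results may well be identical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# `debug` is static, so each value compiles to a separate XLA program.
@partial(jax.jit, static_argnames="debug")
def loss(x, debug=False):
    y = jnp.exp(x) * jnp.tanh(x)
    if debug:
        # This print becomes part of the compiled program and may change fusion decisions.
        jax.debug.print("max(y) = {}", jnp.max(y))
    return jnp.sum(y ** 2)

x = jnp.linspace(-1.0, 1.0, 1000)
plain = loss(x)
debugged = loss(x, debug=True)
print("max abs difference:", float(jnp.abs(plain - debugged)))
```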

#### Asynchronous callbacks

Depending on the backend, `jax.debug.print`s may happen asynchronously, i.e. not in your main program thread. This means that values could be printed to your screen even after your JAX function has returned a value.
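If you need printed output to appear before your program continues, for example before printing a summary of the result, you can wait for pending callbacks with `jax.effects_barrier()`. The function `f` below is a hypothetical example. This is a minimal sketch and does not guarantee anything about timing beyond what `jax.effects_barrier` documents.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    jax.debug.print("inside f: {}", x)
    return x * 2

out = f(jnp.arange(3))
# On some backends the print above may still be pending at this point.
jax.effects_barrier()  # block until outstanding effectful computations (including debug prints) finish
print("after barrier:", out)
```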