
alhridoy
Contributor

This commit adds documentation about how using jax.debug.print and jax.debug.breakpoint can perturb the computation that XLA compiles, potentially leading to numeric discrepancies.

Fixes #26370


google-cla bot commented Feb 22, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@dfm
Contributor

dfm commented Feb 24, 2025

Thanks! Can you fill out the CLA as requested by the bot?

@alhridoy
Contributor Author

@dfm Thank you! I signed.

@dfm
Contributor

dfm commented Feb 24, 2025

Thanks! Although, it looks like there might be a mismatch between the email address you have configured for git, and the one you used to sign the form. Can you click through to that failing workflow for more info, and try to get it green?

@alhridoy alhridoy force-pushed the docs-debug-print-perturbation branch from 1ec3ecb to 60a7a8b on February 24, 2025 19:58
@alhridoy
Contributor Author

> Thanks! Although, it looks like there might be a mismatch between the email address you have configured for git, and the one you used to sign the form. Can you click through to that failing workflow for more info, and try to get it green?

Thank you! I think it is ok now.


#### Computation perturbation

Adding `jax.debug.print` or `jax.debug.breakpoint` statements will change the computation that XLA is asked to compile. This may result in numeric discrepancies compared to the same code without debug statements. Keep this in mind when debugging numerical issues, since adding debug statements might affect the very behavior you're trying to investigate.
Contributor


It might be worth explicitly noting that the reason for this change is that XLA might perform different fusions when compiling.
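A minimal sketch of how one might observe the effect described in this comment, assuming a recent JAX release running on the CPU backend. It builds the same function with and without a `jax.debug.print`, compiles both ahead of time, and compares the optimized HLO text returned by `Compiled.as_text()`. The helper `make_fn` is hypothetical and exists only for illustration; giving both variants the same inner name keeps the module names identical, so any difference comes from the debug statement itself.

```python
import jax
import jax.numpy as jnp


def make_fn(debug: bool):
    # Hypothetical helper: both variants are named `fn`, so their compiled
    # modules differ only because of the debug statement.
    def fn(x):
        y = jnp.sin(x) * 2.0
        if debug:
            # The host callback needs the value of `y`, so XLA has to keep this
            # intermediate available, which can change how it fuses nearby ops.
            jax.debug.print("y = {}", y)
        return jnp.sum(y ** 2)

    return fn


x = jnp.linspace(0.0, 1.0, 8)

plain_fn = jax.jit(make_fn(debug=False))
debug_fn = jax.jit(make_fn(debug=True))

# Optimized HLO that XLA actually compiled for each variant.
hlo_plain = plain_fn.lower(x).compile().as_text()
hlo_debug = debug_fn.lower(x).compile().as_text()
print("compiled programs identical:", hlo_plain == hlo_debug)

out_plain = plain_fn(x)
out_debug = debug_fn(x)  # also prints the intermediate `y`
print(out_plain, out_debug)
```

In a tiny example like this the two results will usually agree closely; the point is only that the compiled programs are not the same, so bitwise-identical results are not guaranteed once debug statements are added. Diffing `hlo_plain` and `hlo_debug` is one way to look for changed fusions in a real model.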

@dfm
Contributor

dfm commented Feb 24, 2025

Thanks! I left a small comment, but this otherwise looks good to me.

@dfm
Contributor

dfm commented Mar 13, 2025

@alhridoy — Sorry, I didn't see that you had pushed a change because I only get notified for comments. Looks good now. Can you squash and rebase your commits onto the current main branch? Once you've done that, please ping me and I'll get this merged. Thanks!!

@alhridoy alhridoy force-pushed the docs-debug-print-perturbation branch from 0fef626 to 55ae0ef on August 30, 2025 20:21
This commit adds documentation about how using jax.debug.print and
jax.debug.breakpoint can perturb the computation that XLA compiles,
potentially leading to numeric discrepancies. This happens because XLA
might perform different fusions when compiling with debug statements.

Fixes jax-ml#26370
@alhridoy alhridoy force-pushed the docs-debug-print-perturbation branch from 55ae0ef to 06e0356 on August 30, 2025 20:33

Development

Successfully merging this pull request may close these issues.

jax.debug.print and jax.debug.breakpoint docs should mention perturbing the computation
