
Conversation

tom-pollak
Contributor

Summary

The documented reference implementation of SDPA doesn't numerically match the MATH backend (SDPBackend.MATH). This causes confusion when testing the numerical accuracy of kernels or other code against MATH.

Changes

Updated the reference implementation to match MATH's actual behavior. The key corrections are:

  • The MATH backend pre-scales both the query and key tensors (each by the square root of the scale factor) before the matmul, for numerical stability, rather than applying the scale after the matmul.
  • The MATH backend internally upcasts fp16/bf16 inputs to float32, then converts back to the original dtype at the end.
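
For illustration, here's a minimal sketch of a reference that follows those two behaviors (the function name and simplifications are mine; masking, dropout, and GQA handling are omitted):

```python
import math
import torch

def sdpa_reference(query, key, value, scale=None):
    # Simplified sketch: no attention mask, dropout, or GQA handling.
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    orig_dtype = query.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        # MATH upcasts half-precision inputs to float32 internally.
        query, key, value = (t.float() for t in (query, key, value))
    # Pre-scale Q and K by sqrt(scale_factor) each, instead of scaling Q @ K^T afterwards.
    query = query * math.sqrt(scale_factor)
    key = key * math.sqrt(scale_factor)
    attn_weight = torch.softmax(query @ key.transpose(-2, -1), dim=-1)
    return (attn_weight @ value).to(orig_dtype)
```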

Added a regression test, test_reference_implementation_bitwise_match_math_backend, which verifies an exact bitwise match (rtol=0, atol=0) between the reference implementation and MATH.
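
In spirit, the comparison in the test looks something like the following (a sketch, not the test's actual code; `sdpa_reference` is the hypothetical helper from the sketch above, and the real reference in the docs includes the masking/dropout/GQA pieces needed for a true bitwise match):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 4, 8, 16, dtype=torch.bfloat16) for _ in range(3))
with sdpa_kernel(SDPBackend.MATH):
    expected = F.scaled_dot_product_attention(q, k, v)
actual = sdpa_reference(q, k, v)
# rtol=0, atol=0 requires an exact bitwise match, not just closeness.
torch.testing.assert_close(actual, expected, rtol=0, atol=0)
```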

MATH backend pre-scales Q and K before the softmax, leading to numerical differences
when comparing the ref impl with MATH.

Now matches SDPBackend.MATH with `rtol=0., atol=0.`

pytorch-bot bot commented Sep 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163508

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 7cb5f4f with merge base 96a3afb:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@tom-pollak
Contributor Author

@pytorchbot label "module: sdpa"

@pytorch-bot pytorch-bot bot added the module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention label Sep 22, 2025
@tom-pollak
Contributor Author

@pytorchbot label "release notes: nn"

@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Sep 22, 2025
Just whitespace changes, but no longer a copy-paste from the docs
@albanD albanD removed their request for review September 22, 2025 14:13
@jbschlosser jbschlosser requested a review from drisspg September 22, 2025 15:52
Contributor

@jbschlosser jbschlosser left a comment

Thanks for the contribution! I agree with the fixes and it's nice to have bitwise-accurate validation against the math backend via tests.

One concern I have is that the test won't catch updates to the reference implementation in the docs. Running this via the doctest mechanism would address this. cc @svekars for insight on how to ensure this validation happens during doctest time

Don't think dropout is seeded the same way on some archs, leading to
different dropout masks
@tom-pollak
Contributor Author

Thanks! Agree, originally I had a doctest but I thought it might pollute the page since it would have to be immediately below the code block. This way it should still catch regressions to the MATH kernel, which seems more likely to introduce a breaking change. Happy to go either way though.

@jbschlosser
Contributor

Agree, originally I had a doctest but I thought it might pollute the page since it would have to be immediately below the code block.

If it's just a few lines to compare reference vs. math, my opinion is that this doesn't pollute the page too much. In fact, it makes it very clear that the reference in the docs is what users should expect from the math backend.

@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Sep 22, 2025
@tom-pollak
Contributor Author

tom-pollak commented Sep 23, 2025

Looking into this: I'm not sure there's a good way to integrate with doctest. Since SDPA isn't defined in a Python def, xdoctest doesn't discover it when scanning (even with --analysis=dynamic).

The current sdpa examples aren't actually run by xdoctest; you can verify this with:

export XDOCTEST_GLOBAL_EXEC="from torch import nn\nimport torch.nn.functional as F\nimport torch"
xdoctest -m torch.nn.functional --analysis dynamic
<examples are run, but not sdpa>

I could add a "fake" doctest in, but that might be more confusing, or I think we'd need to refactor.

@tom-pollak tom-pollak requested a review from drisspg September 24, 2025 10:21
@tom-pollak
Contributor Author

@drisspg Ok to be merged?

Contributor

@drisspg drisspg left a comment

Looks good, thanks

@tom-pollak
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 2, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 2, 5, linux.g6.4xlarge.experimental.nvidia.gpu)

Details for Dev Infra team: raised by workflow job

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Oct 2, 2025
@tom-pollak
Contributor Author

@drisspg my bad, needed to put temp_mask on the attn_bias device. Fixed now!
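
For context, the fix is of this shape (illustrative sizes and names, not the exact diff): the causal mask has to be created on the same device as `attn_bias`, otherwise the masked_fill can hit a device mismatch on CUDA:

```python
import torch

L, S = 4, 4  # illustrative sizes
attn_bias = torch.zeros(L, S)  # may live on a CUDA device in practice
# Build the causal mask on attn_bias's device to avoid a cross-device masked_fill.
temp_mask = torch.ones(L, S, dtype=torch.bool, device=attn_bias.device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
```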
