[SDPA] Add an optional scale kwarg #95259
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95259
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 2 Pending as of commit ac517f3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 6f295cc to 143b654
@@ -3,12 +3,13 @@
#include <c10/macros/Export.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>
#include "c10/util/Optional.h"
fix
@@ -247,6 +250,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
  p.grad_key_ptr = (scalar_t*)grad_k.data_ptr();
  p.grad_value_ptr = (scalar_t*)grad_v.data_ptr();
  p.delta_ptr = (float*)delta.data_ptr();
  p.scale = scale.has_value() ? 1.0f / scale.value()
It seems easier to not make scale optional than to do this ternary dance everywhere.
Okay, figured out why the dance:
Scale needs to be the last arg, but it would then come after args that already have defaults, so it needs a default of its own; otherwise: `error: missing default argument on parameter 'scale'`
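For reference, here is a minimal Python sketch (not the PR's kernel code; it ignores masking, dropout, and is_causal) of the fallback behavior that the optional-scale ternaries implement: when no scale is passed, the kernels fall back to 1/sqrt(head_dim).

```python
import math
import torch

def sdpa_math_sketch(q, k, v, scale=None):
    # Fallback default: 1/sqrt(head_dim) when no scale is supplied; this is the
    # behavior the has_value() ternaries implement in the C++ kernels.
    scale_factor = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale_factor, dim=-1)
    return attn @ v
```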
Force-pushed from 41fd53d to a7cb685
Make sure that the dynamo graph break mechanism handles this new kwarg gracefully. It should, since it's kwarg-only, but update the tests.
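As a rough illustration of the kind of check being asked for (a sketch that assumes a PyTorch build with torch.compile available; this is not the actual test updated in the PR):

```python
import torch
import torch.nn.functional as F

# Sketch: compile a function that passes the new keyword-only `scale` kwarg and
# check that it runs end to end without dynamo choking on the kwarg.
@torch.compile
def attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v, scale=0.125)

q = k = v = torch.randn(2, 4, 8, 16)
print(attn(q, k, v).shape)  # torch.Size([2, 4, 8, 16])
```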
Force-pushed from 1f3e274 to b9f32c0
Force-pushed from bcfa9c4 to ce9811b
Force-pushed from 637a503 to cd15ae8
See leftover comments. Otherwise this seems good to go.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
# Summary
This PR adds an optional `scale` kwarg to `torch.nn.functional.scaled_dot_product_attention()`. The new kwarg is a scaling factor that is applied after the `q@k.T` step of the computation. The efficient kernel was updated to support it, and the flash and math kernels were minimally updated to support it as well.
This will reduce the complexity of #94729 and has been asked for by a couple of users.
# Review Highlights
- As far as I know I did this the correct way, and it is both BC and FC compliant. However, I always seem to break internal workloads, so I would love it if someone could advise on whether I did this right.
- I named the optional arg `scale`. This is probably dumb and I should name it `scale_factor`. I will make that change, but it is annoying and will require someone deciding we should rename it.
- `scale` is interpreted as `q@k.T * (scale)`

Pull Request resolved: pytorch/pytorch#95259
Approved by: https://github.com/cpuhrsch
cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire
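To close, a small usage sketch (illustrative only, not part of the PR) showing how the new kwarg is interpreted: the attention weights are computed from `q@k.T * scale`, so an explicit scale equal to the implicit default should match a manual computation.

```python
import math
import torch
import torch.nn.functional as F

# Illustrative usage sketch: an explicit `scale` equal to the implicit default
# 1/sqrt(E) should match a manual softmax(q @ k.T * scale) @ v computation
# with no mask and no dropout.
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
scale = 1.0 / math.sqrt(q.size(-1))

out = F.scaled_dot_product_attention(q, k, v, scale=scale)
ref = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))  # expected: True (up to kernel numerics)
```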