
Conversation

@drisspg (Contributor) commented Feb 22, 2023

Summary

This PR adds an optional kwarg to torch.nn.functional.scaled_dot_product_attention().
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. The efficient kernel was updated to support it, and the flash and math kernels were minimally updated to support it as well.

This will reduce the complexity of #94729 and has been requested by a couple of users.

Review Highlights

  • As far as I know I did this the correct way, and it is both BC and FC compliant. However, I always seem to break internal workloads, so I would love it if someone could advise whether I did this right.
  • I named the optional arg 'scale'. It should probably be 'scale_factor' instead; I will make that change, but it is annoying, so it needs someone to agree that we should rename it.
  • 'scale' is interpreted as `q@k.T * scale`; a usage sketch follows below.
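
A minimal usage sketch of the new kwarg (assuming the final keyword-only signature, where leaving `scale=None` falls back to the existing `1/sqrt(E)` behavior):

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Default behavior: softmax scale of 1 / sqrt(head_dim).
out_default = F.scaled_dot_product_attention(q, k, v)

# Passing the same value explicitly should match the default.
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(q.size(-1)))
torch.testing.assert_close(out_default, out_explicit)

# A custom scale, e.g. for models that fold the scaling in elsewhere.
out_custom = F.scaled_dot_product_attention(q, k, v, scale=1.0)
```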

cc @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire


pytorch-bot bot commented Feb 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95259

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 2 Pending

As of commit ac517f3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg force-pushed the sdpa_optional_scale_kwarg branch 2 times, most recently from 6f295cc to 143b654 on February 22, 2023 03:00
@drisspg drisspg marked this pull request as ready for review February 22, 2023 03:03
@drisspg drisspg added release notes: nn release notes category topic: new features topic category labels Feb 22, 2023
@drisspg drisspg requested a review from cpuhrsch February 22, 2023 03:10
@@ -3,12 +3,13 @@
#include <c10/macros/Export.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>
#include "c10/util/Optional.h"
drisspg (Contributor, Author) commented:

fix

@@ -247,6 +250,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
p.grad_key_ptr = (scalar_t*)grad_k.data_ptr();
p.grad_value_ptr = (scalar_t*)grad_v.data_ptr();
p.delta_ptr = (float*)delta.data_ptr();
p.scale = scale.has_value() ? 1.0f / scale.value()
Contributor:
It seems easier to not make scale optional and do this ternary dance everywhere

@drisspg (Contributor, Author) commented Feb 23, 2023:
Okay, figured out why the dance is needed: scale needs to be the last arg, but then it would come after args that already have defaults, so it needs a good default of its own; otherwise the compiler raises `error: missing default argument on parameter 'scale'`.
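
For reference, a simplified Python sketch (illustrative names, not the actual kernel code) of what the optional argument resolves to: when `scale` is not provided, the computation falls back to the existing `1/sqrt(E)` scaling.

```python
import math
from typing import Optional
import torch

def sdpa_math_reference(q, k, v, scale: Optional[float] = None):
    # Resolve the optional scale the same way the kernels' ternary does:
    # use the caller's value if given, otherwise fall back to 1/sqrt(head_dim).
    softmax_scale = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    # No mask, dropout, or causal handling here; this only shows the scaling step.
    attn = torch.softmax(q @ k.transpose(-2, -1) * softmax_scale, dim=-1)
    return attn @ v
```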

@drisspg drisspg force-pushed the sdpa_optional_scale_kwarg branch from 41fd53d to a7cb685 on February 23, 2023 03:08
@drisspg drisspg (Contributor, Author) left a comment

Make sure that the dynamo graph-break mechanism handles this new kwarg gracefully. It should, since the arg is kwarg-only, but the tests need to be updated.
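
As a rough illustration of the kind of check meant here (a hypothetical sketch, not the test actually added in this PR), one can compile a function that passes the new kwarg and confirm it matches eager mode:

```python
import torch
import torch.nn.functional as F

def attn_with_scale(q, k, v):
    # Exercise the new keyword-only `scale` kwarg through torch.compile/dynamo.
    return F.scaled_dot_product_attention(q, k, v, scale=0.5)

compiled = torch.compile(attn_with_scale)
q = k = v = torch.randn(2, 4, 16, 32)
torch.testing.assert_close(compiled(q, k, v), attn_with_scale(q, k, v))
```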

@drisspg drisspg force-pushed the sdpa_optional_scale_kwarg branch 2 times, most recently from 1f3e274 to b9f32c0 on March 3, 2023 18:53
@drisspg drisspg force-pushed the sdpa_optional_scale_kwarg branch from bcfa9c4 to ce9811b on March 5, 2023 18:56
@drisspg drisspg force-pushed the sdpa_optional_scale_kwarg branch from 637a503 to cd15ae8 on March 7, 2023 19:40
@drisspg drisspg requested a review from cpuhrsch March 8, 2023 00:04
@cpuhrsch cpuhrsch (Contributor) left a comment

See leftover comments. Otherwise this seems good to go.

@drisspg drisspg added the topic: improvements topic category label Mar 8, 2023
@drisspg drisspg (Contributor, Author) commented Mar 8, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 8, 2023
@pytorchmergebot (Collaborator) commented
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
Pull Request resolved: pytorch/pytorch#95259
Approved by: https://github.com/cpuhrsch
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 12, 2023
ydwu4 added a commit to ydwu4/pytorch that referenced this pull request Mar 13, 2023
Labels: ciflow/trunk, Merged, module: dynamo, release notes: nn, topic: improvements, topic: new features