sampled_addmm: backward performance improvements #103544
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103544
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 21b7dc1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
const auto grad_projected = grad.sparse_mask(self);
const auto self_requires_grad = grad_input_mask[0];
const auto mat1_requires_grad = grad_input_mask[1];
const auto mat2_requires_grad = grad_input_mask[2];
return std::make_tuple(
```
@albanD, at this point it is guaranteed that at least one entry of `grad_input_mask` is true, right?
yes, it should be safe to assume that
I'm not sure we want to do that, because this will regress memory usage :/
In particular, if only self requires grad, m1 and m2 won't be saved today. But with this new code, they will always be saved.
I think there was a discussion about fixing this; do you remember where it is, @soulitzer?
Maybe something like `sparse_sampled_addmm_backward(grad, self, optional_save_if(mat2.requires_grad(), mat1), optional_save_if(mat1.requires_grad(), mat2), alpha, beta, grad_input_mask)` and make the backward take `optional<>` args.
We then need to use our pattern matcher for replacement to make a smart decision on saving.
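For illustration only, here is a minimal sketch of what that conditional-saving pattern could look like. `optional_save_if` is a hypothetical helper taken from the comment above (no such helper exists in ATen today), and the backward signature below is an assumption about the shape of the proposal, not the codegen's actual output:

```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <array>
#include <tuple>

// Hypothetical helper: only hand a tensor to the autograd graph when the
// matching input actually needs it for its gradient.
inline c10::optional<at::Tensor> optional_save_if(bool cond, const at::Tensor& t) {
  return cond ? c10::optional<at::Tensor>(t) : c10::nullopt;
}

// Hypothetical backward signature with optional<> args: mat1 is only needed
// to compute mat2's gradient and vice versa, so either can be dropped when
// the corresponding gradient is not requested.
std::tuple<at::Tensor, at::Tensor, at::Tensor> sparse_sampled_addmm_backward(
    const at::Tensor& grad,
    const at::Tensor& self,
    const c10::optional<at::Tensor>& mat1,
    const c10::optional<at::Tensor>& mat2,
    const at::Scalar& alpha,
    const at::Scalar& beta,
    std::array<bool, 3> grad_input_mask);

// The forward side would then save conditionally, roughly:
//   sparse_sampled_addmm_backward(grad, self,
//       optional_save_if(mat2.requires_grad(), mat1),
//       optional_save_if(mat1.requires_grad(), mat2),
//       alpha, beta, grad_input_mask);
```

As the autogenerated code later in this thread shows, the codegen route that eventually landed (#103750) instead guards the `SavedVariable` construction with `should_compute_output`, but the saving condition is the same: mat1 is kept only when mat2's gradient is needed, and vice versa.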
tools/autograd/derivatives.yaml (outdated)
```diff
- self: maybe_multiply(grad, beta.conj())
- mat1: maybe_multiply(grad.sparse_mask(self).mm(mat2.mH()), alpha.conj())
- mat2: maybe_multiply(mat1.mH().mm(grad.sparse_mask(self)), alpha.conj())
+ self, mat1, mat2: sparse_sampled_addmm_backward(grad, self, mat1, mat2, alpha, beta, grad_input_mask)
```
To avoid any confusion: the gradient wrt `self` is still incorrect. It is fixed in #103548.
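For readers following along, here is a rough standalone sketch of what the fused entry computes, written directly from the formulas above. It is not the actual `FunctionsManual.cpp` implementation, and the `self` line deliberately keeps the (still incorrect, per the note above) pre-#103548 form:

```cpp
#include <ATen/ATen.h>
#include <array>
#include <tuple>

// Sketch of the fused backward: project `grad` onto the sparsity pattern of
// `self` once and reuse the result for both matrix gradients, instead of
// calling sparse_mask() separately in the mat1 and mat2 formulas.
std::tuple<at::Tensor, at::Tensor, at::Tensor> sparse_sampled_addmm_backward_sketch(
    const at::Tensor& grad,
    const at::Tensor& self,
    const at::Tensor& mat1,
    const at::Tensor& mat2,
    const at::Scalar& alpha,
    const at::Scalar& beta,
    std::array<bool, 3> grad_input_mask) {
  const auto grad_projected = grad.sparse_mask(self);
  at::Tensor grad_self, grad_mat1, grad_mat2;
  if (grad_input_mask[0]) {
    grad_self = grad * beta.conj();  // note: still incorrect wrt `self`; see #103548 for the fix
  }
  if (grad_input_mask[1]) {
    grad_mat1 = grad_projected.mm(mat2.mH()) * alpha.conj();
  }
  if (grad_input_mask[2]) {
    grad_mat2 = mat1.mH().mm(grad_projected) * alpha.conj();
  }
  return std::make_tuple(grad_self, grad_mat1, grad_mat2);
}
```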
Here's the issue about multi-output functions saving unnecessary tensors: #97575
@albanD, @soulitzer, would the current code be sufficient for now? Or is there a way to tell the autogen to produce and reuse some intermediates?
There's no way to tell the codegen to do this currently, unfortunately. There needs to be an API to specify what needs to be saved under what conditions, and then some codegen updates to translate that information into extra logic in the VariableType kernel. And what Alban is proposing above could be promising.
What @albanD suggested is already applied here :)
So if we land this PR as-is, we're just trading memory for compute. This trade-off could be worth it, though, if we could say that in practice all the inputs tend to require grad.
Oh, I believe what Alban is suggesting is something slightly different: the `grad_input_mask` is a quantity computed when backward is run, so it would not influence how things are saved during the forward pass.
Ah, I see, I was not aware of that; will change it then.
Good catch!
No need to do double `sparse_mask`, let's squash everything into one call!

This PR exercises #103750, so here is the autogenerated code for the backward pass:

```cpp
at::Tensor sparse_sampled_addmm(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) {
  auto& self_ = unpack(self, "self", 0);
  auto& mat1_ = unpack(mat1, "mat1", 1);
  auto& mat2_ = unpack(mat2, "mat2", 2);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( self, mat1, mat2 );

  std::shared_ptr<SparseSampledAddmmBackward0> grad_fn;
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<SparseSampledAddmmBackward0>(new SparseSampledAddmmBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self, mat1, mat2 ));
    grad_fn->alpha = alpha;
    grad_fn->beta = beta;
    if (grad_fn->should_compute_output(2)) {
      grad_fn->mat1_ = SavedVariable(mat1, false);
    }
    if (grad_fn->should_compute_output(1)) {
      grad_fn->mat2_ = SavedVariable(mat2, false);
    }
    grad_fn->self_ = SavedVariable(self, false);
  }
```

As you can see, we do not save tensors unless needed.
@pytorchbot merge
Merge failed. Reason: This PR is missing a required label. To add a label, you can comment to pytorchbot (for example, `@pytorchbot label ...`). For more information, see the linked docs.
Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
No need to do double `sparse_mask`, let's squash everything into one call!

This PR exercises #103750; the autogenerated code for the backward pass is shown in the update above. As you can see, we do not save tensors unless needed.
Stack from ghstack (oldest at bottom):
cc @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer @ezyang @albanD @zou3519 @gqchen @soulitzer @lezcano @Varal7