Commit 5cf3a99

nikitaved authored and pytorchmergebot committed
sampled_addmm: backward performance improvements (#103544)
No need to do double `sparse_mask`, let's squash everything into one call! This PR exercises #103750, so here is the autogenerated code for the backward pass.

```
at::Tensor sparse_sampled_addmm(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) {
  auto& self_ = unpack(self, "self", 0);
  auto& mat1_ = unpack(mat1, "mat1", 1);
  auto& mat2_ = unpack(mat2, "mat2", 2);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( self, mat1, mat2 );
  std::shared_ptr<SparseSampledAddmmBackward0> grad_fn;
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<SparseSampledAddmmBackward0>(new SparseSampledAddmmBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self, mat1, mat2 ));
    grad_fn->alpha = alpha;
    grad_fn->beta = beta;
    if (grad_fn->should_compute_output(2)) {
      grad_fn->mat1_ = SavedVariable(mat1, false);
    }
    if (grad_fn->should_compute_output(1)) {
      grad_fn->mat2_ = SavedVariable(mat2, false);
    }
    grad_fn->self_ = SavedVariable(self, false);
  }
```

As you can see, we do not save tensors unless needed.

Pull Request resolved: #103544
Approved by: https://github.com/soulitzer
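For intuition, here is a small dense-tensor sketch of the three gradients the new backward produces. This is an illustration in Python, not PyTorch's actual sparse kernel: `mask` stands in for the sparsity pattern of `self`, so `grad.sparse_mask(self)` becomes `grad * mask`, and the point of the change is that this projection is computed once and shared by both matrix gradients.

```python
import torch

# Dense-reference sketch (illustration only; real inputs, so the conj() calls
# from the C++ formulas are no-ops here).
m, k, n = 4, 3, 5
mask = (torch.rand(m, n) > 0.5).double()          # sparsity pattern of `self`
self_ = torch.randn(m, n, dtype=torch.double) * mask
mat1 = torch.randn(m, k, dtype=torch.double)
mat2 = torch.randn(k, n, dtype=torch.double)
grad = torch.randn(m, n, dtype=torch.double)      # incoming gradient w.r.t. the output
alpha, beta = 2.0, 3.0

# The masked projection is computed once and reused, instead of twice as before.
grad_projected = grad * mask
grad_self = beta * grad                           # self: maybe_multiply(grad, beta)
grad_mat1 = alpha * grad_projected @ mat2.mH      # mat1: projected grad @ mat2^H
grad_mat2 = alpha * mat1.mH @ grad_projected      # mat2: mat1^H @ projected grad

# Sanity check against autograd on an equivalent dense formulation:
# out = beta * self + alpha * (mat1 @ mat2) restricted to self's pattern.
for t in (self_, mat1, mat2):
    t.requires_grad_(True)
out = beta * self_ + alpha * (mat1 @ mat2) * mask
out.backward(grad)
assert torch.allclose(self_.grad, grad_self)
assert torch.allclose(mat1.grad, grad_mat1)
assert torch.allclose(mat2.grad, grad_mat2)
```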
1 parent: 148960b

File tree

3 files changed, +39 −3 lines

tools/autograd/derivatives.yaml

Lines changed: 5 additions & 3 deletions

```
@@ -2486,9 +2486,11 @@
   result: replication_pad3d_backward_symint(grad_output_t, self_p, padding)
 
 - name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
-  self: maybe_multiply(grad, beta.conj())
-  mat1: maybe_multiply(grad.sparse_mask(self).mm(mat2.mH()), alpha.conj())
-  mat2: maybe_multiply(mat1.mH().mm(grad.sparse_mask(self)), alpha.conj())
+  self, mat1, mat2: "sparse_sampled_addmm_backward(grad,
+                     self,
+                     wrap_opt_if(mat1, grad_input_mask[2]),
+                     wrap_opt_if(mat2, grad_input_mask[1]),
+                     alpha, beta, grad_input_mask)"
 
 - name: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
   output_differentiability: [True, False]
```
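Note the crossed indices in the new formula: `mat1` is guarded by `grad_input_mask[2]` and `mat2` by `grad_input_mask[1]`, because each operand is only needed to compute the *other* operand's gradient. `wrap_opt_if` (from #103750) roughly passes a tensor through as an optional only when its condition holds, which is what lets the generated code in the commit message skip saving tensors the requested gradients never touch. A tiny Python model of that bookkeeping, with hypothetical helper and parameter names (not PyTorch internals):

```python
from typing import Optional, Tuple
import torch

def saved_for_backward(
    mat1: torch.Tensor,
    mat2: torch.Tensor,
    grad_input_mask: Tuple[bool, bool, bool],  # (needs_self, needs_mat1, needs_mat2)
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Hypothetical helper mirroring the two wrap_opt_if calls: an input is
    # stashed for backward only when some other input's gradient needs it.
    saved_mat1 = mat1 if grad_input_mask[2] else None  # mat1 only feeds mat2's gradient
    saved_mat2 = mat2 if grad_input_mask[1] else None  # mat2 only feeds mat1's gradient
    return saved_mat1, saved_mat2
```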

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 26 additions & 0 deletions

```
@@ -1475,6 +1475,32 @@ static Tensor sparse_mask_like_grad(const Tensor& x, const Tensor& gx) {
   }
 }
 
+std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
+    const Tensor& grad,
+    const Tensor& self,
+    const c10::optional<Tensor>& mat1,
+    const c10::optional<Tensor>& mat2,
+    const Scalar& alpha,
+    const Scalar& beta,
+    const std::array<bool, 3>& grad_input_mask) {
+  if (!grad.defined()) {
+    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
+  }
+
+  const auto grad_projected = grad.sparse_mask(self);
+  const auto self_requires_grad = grad_input_mask[0];
+  const auto mat1_requires_grad = grad_input_mask[1];
+  const auto mat2_requires_grad = grad_input_mask[2];
+  return std::make_tuple(
+      self_requires_grad ? maybe_multiply(grad, beta.conj()) : Tensor{},
+      mat1_requires_grad
+          ? maybe_multiply(grad_projected.mm(mat2->mH()), alpha.conj())
+          : Tensor{},
+      mat2_requires_grad
+          ? maybe_multiply(mat1->mH().mm(grad_projected), alpha.conj())
+          : Tensor{});
+}
+
 Tensor sparse_sparse_matmul_backward(
     const Tensor& grad,
     const Tensor& a,
```

torch/csrc/autograd/FunctionsManual.h

Lines changed: 8 additions & 0 deletions

```
@@ -297,6 +297,14 @@ at::Tensor mm_mat1_sparse_backward(
     const at::Tensor& mat1,
     const at::Tensor& mat2,
     const at::Scalar& alpha);
+std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
+    const Tensor& grad,
+    const Tensor& self,
+    const c10::optional<Tensor>& mat1,
+    const c10::optional<Tensor>& mat2,
+    const Scalar& alpha,
+    const Scalar& beta,
+    const std::array<bool, 3>& grad_input_mask);
 at::Tensor sparse_sparse_matmul_backward(
     const at::Tensor& grad,
     const at::Tensor& mat1,
```
