27 changes: 24 additions & 3 deletions tools/autograd/derivatives.yaml
@@ -57,8 +57,8 @@
# differentiable.
#
# Gradient expressions are standard C++ expressions operating on ATen
# variables. In a gradient expression, the following variables are in
# scope:
# variables. In a gradient expression, the following variables/functions
# are in scope:
#
# - 'grad', the gradient of the output (often spelled grad_output
# in Python) which we are going to left-multiply.
@@ -110,6 +110,23 @@
# destroy saved buffers if we know variables are not going to be retained,
# e.g., it is used by _cudnn_rnn
#
# - `wrap_opt_if` is a 2-argument function that accepts a tensor
# variable and a boolean condition that dictates whether to save that
# variable in the graph. The result of this function is `c10::optional<Tensor>`:
# it is `c10::nullopt` when the condition evaluates to `false`,
# otherwise it is the variable wrapped in `c10::optional<Tensor>`.
# For example, wrap_opt_if(var_0, grad_input_mask[1] || grad_input_mask[2])
# means that `var_0` is saved as long as the second (grad_input_mask[1])
# or the third (grad_input_mask[2]) argument requires gradients.
# Put differently, `var_0` is needed in the backward computation of the
# second or the third argument.
# NOTE: using `var_i.requires_grad()` in the conditional expression
# is not supported; use `grad_input_mask[i]` instead.
# NOTE: `wrap_opt_if` can be used to avoid saving redundant variables
# with multi-output backward formulas.
# See https://github.com/pytorch/pytorch/issues/97575 for more details
# on the issue.
#
# If you need a complex expression, e.g., with local variables,
# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp
# and invoke it from here. By the way, go read
@@ -1870,7 +1887,11 @@

# NN
- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
  i1, i2, i3: _trilinear_backward(grad, i1, i2, i3, expand1, expand2, expand3, sumdim, grad_input_mask)
  i1, i2, i3: "_trilinear_backward(grad,
              wrap_opt_if(i1, grad_input_mask[1] || grad_input_mask[2]),
              wrap_opt_if(i2, grad_input_mask[0] || grad_input_mask[2]),
              wrap_opt_if(i3, grad_input_mask[0] || grad_input_mask[1]),
              expand1, expand2, expand3, sumdim, grad_input_mask)"
  result: "_trilinear(i1_t, i2_p, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) +
           _trilinear(i1_p, i2_t, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) +
           _trilinear(i1_p, i2_p, i3_t, expand1, expand2, expand3, sumdim, unroll_dim)"
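For context on the conditions in the new formula: grad_i1 only reads i2 and i3, grad_i2 only reads i1 and i3, and grad_i3 only reads i1 and i2, so each input has to be saved only when one of the other two gradients is requested. A minimal stand-in sketch of that rule (plain `std::optional<int>` in place of `c10::optional<Tensor>`; not PyTorch code):

```cpp
#include <array>
#include <iostream>
#include <optional>

// Stand-in for the wrap_opt_if helper: keep `v` only when `cond` holds.
template <typename T>
std::optional<T> wrap_opt_if(const T& v, bool cond) {
  return cond ? std::optional<T>(v) : std::optional<T>(std::nullopt);
}

int main() {
  // Which of the three _trilinear inputs require gradients.
  std::array<bool, 3> grad_input_mask = {true, false, false};

  // Mirrors the conditions in the yaml entry above: each input is saved
  // only if one of the *other* two gradients will be computed.
  auto saved_i1 = wrap_opt_if(10, grad_input_mask[1] || grad_input_mask[2]);
  auto saved_i2 = wrap_opt_if(20, grad_input_mask[0] || grad_input_mask[2]);
  auto saved_i3 = wrap_opt_if(30, grad_input_mask[0] || grad_input_mask[1]);

  std::cout << saved_i1.has_value() << saved_i2.has_value()
            << saved_i3.has_value() << '\n';  // prints 011: only i2, i3 kept
}
```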
32 changes: 29 additions & 3 deletions tools/autograd/gen_variable_type.py
@@ -25,6 +25,7 @@
# which will in turn dispatch back to VariableType for its
# differentiable subcomponents.
#
import re
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

from torchgen.api import cpp
@@ -1201,17 +1202,42 @@ def guard_for(arg: SavedAttribute) -> Optional[str]:
if len(used_in) != 1:
    return None
derivative = used_in[0]

# Case with multioutput formulas
# TODO: process all derivative formulas!!!
if len(derivative.var_names) != 1:
    wrap_opt_if_start = derivative.formula.find(
        f"wrap_opt_if({arg.nctype.name}"
    )
    if wrap_opt_if_start == -1:
        return None

    wrap_opt_if_match = re.match(
        rf"wrap_opt_if\({arg.nctype.name},(.*?)\)",
        derivative.formula[wrap_opt_if_start:],
    )
    assert wrap_opt_if_match is not None

    # Condition is between 'wrap_opt_if(var_name,' and ')'.
    condition_slice = slice(len(rf"wrap_opt_if\({arg.nctype.name},"), -1)
    wrap_opt_if_condition = wrap_opt_if_match.group(0)[
        condition_slice
    ].strip()
    # replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)'
    wrap_opt_if_condition = re.sub(
        r"grad_input_mask\[(\d+)\]",
        r"grad_fn->should_compute_output(\1)",
        wrap_opt_if_condition,
    )
    return f"{wrap_opt_if_condition}"

# Figure out the offset of the edge that uses this variable
derivative_var_name = derivative.var_names[0]
for edge_off, a in enumerate(args_with_derivatives):
    if a.name == derivative_var_name:
        break
else:
    raise AssertionError()

return f"grad_fn->should_compute_output({edge_off})"

if is_inplace_foreach:
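The new guard_for branch above pulls the condition out of `wrap_opt_if(var, cond)` in the formula string and rewrites each `grad_input_mask[i]` into `grad_fn->should_compute_output(i)`, which becomes the C++ guard around saving that variable. A standalone sketch of the same string rewrite using `std::regex` (illustrative only; the actual transformation is the Python code above):

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
  // Condition as written inside wrap_opt_if(...) in derivatives.yaml.
  const std::string cond = "grad_input_mask[1] || grad_input_mask[2]";

  // Rewrite grad_input_mask[N] into grad_fn->should_compute_output(N), i.e.
  // turn the yaml-level mask check into the runtime guard emitted by codegen.
  const std::string guard = std::regex_replace(
      cond,
      std::regex(R"(grad_input_mask\[(\d+)\])"),
      "grad_fn->should_compute_output($1)");

  std::cout << guard << '\n';
  // grad_fn->should_compute_output(1) || grad_fn->should_compute_output(2)
}
```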
12 changes: 6 additions & 6 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -4870,9 +4870,9 @@ infinitely_differentiable_native_group_norm_backward(

std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
    const Tensor& grad_out,
    const Tensor& i1,
    const Tensor& i2,
    const Tensor& i3,
    const c10::optional<Tensor>& i1,
    const c10::optional<Tensor>& i2,
    const c10::optional<Tensor>& i3,
    IntArrayRef expand1,
    IntArrayRef expand2,
    IntArrayRef expand3,
@@ -4882,13 +4882,13 @@ std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
  if (grad_out.defined()) {
    if (grad_mask[0])
      grad_i1 =
          at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1);
          at::_trilinear(grad_out, *i2, *i3, sumdim, expand2, expand3, expand1);
    if (grad_mask[1])
      grad_i2 =
          at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2);
          at::_trilinear(*i1, grad_out, *i3, expand1, sumdim, expand3, expand2);
    if (grad_mask[2])
      grad_i3 =
          at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3);
          at::_trilinear(*i1, *i2, grad_out, expand1, expand2, sumdim, expand3);
  }
  return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}
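With the signature above, an input that was not saved arrives as `c10::nullopt` and is dereferenced only under the matching `grad_mask` entry, which is exactly the combination the wrap_opt_if conditions in derivatives.yaml allow. A toy sketch of that pattern (scalar products standing in for the `at::_trilinear` calls, `std::optional` in place of `c10::optional<Tensor>`):

```cpp
#include <array>
#include <iostream>
#include <optional>
#include <tuple>

// Toy backward: each output is computed only when its mask entry is set, and
// only then are the optional inputs it needs dereferenced, mirroring how
// _trilinear_backward uses *i1, *i2, *i3 under grad_mask.
std::tuple<int, int, int> toy_backward(
    int grad_out,
    const std::optional<int>& i1,
    const std::optional<int>& i2,
    const std::optional<int>& i3,
    std::array<bool, 3> grad_mask) {
  int g1 = 0, g2 = 0, g3 = 0;
  if (grad_mask[0]) g1 = grad_out * (*i2) * (*i3);  // needs i2, i3
  if (grad_mask[1]) g2 = (*i1) * grad_out * (*i3);  // needs i1, i3
  if (grad_mask[2]) g3 = (*i1) * (*i2) * grad_out;  // needs i1, i2
  return {g1, g2, g3};
}

int main() {
  // Only the first gradient is requested, so i1 was never saved (nullopt)
  // while i2 and i3 were: the combination the yaml conditions permit.
  auto [g1, g2, g3] = toy_backward(2, std::nullopt, 3, 4, {true, false, false});
  std::cout << g1 << ' ' << g2 << ' ' << g3 << '\n';  // 24 0 0
}
```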
11 changes: 8 additions & 3 deletions torch/csrc/autograd/FunctionsManual.h
@@ -38,6 +38,11 @@ Tensor toNonOptFwGrad(const c10::optional<Tensor>& t);
Tensor toNonOptPrimal(const c10::optional<Tensor>& t);
Tensor toNonOptTensor(const c10::optional<Tensor>& t);

inline c10::optional<Tensor> wrap_opt_if(const Tensor& t, const bool cond) {
  using OptTensor = c10::optional<Tensor>;
  return cond ? OptTensor(t) : static_cast<OptTensor>(c10::nullopt);
}

Comment on lines +41 to +45

@nikitaved (Collaborator, Author), Jun 16, 2023:
For the context: this one is used in the next PR up in the stack for sparse_sampled_addmm_backward.

Contributor:
If codegen did its job correctly, we would always get an undefined tensor t here right? Maybe we can assert for that here.

@nikitaved (Collaborator, Author), Jun 20, 2023:
No, not really, unfortunately. This code is being run at backward compute. But we can assert inside backward implementations for sure to test both conditions and savings.

Contributor:
oh whoops, good point.
Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction);
bool any_variable_defined(const variable_list& variables);
void copy_range(variable_list& out, IndexRange range, const at::Tensor& t);
@@ -639,9 +644,9 @@ std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
    std::array<bool, 2> output_mask);
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
    const Tensor& grad_out,
    const Tensor& i1,
    const Tensor& i2,
    const Tensor& i3,
    const c10::optional<Tensor>& i1,
    const c10::optional<Tensor>& i2,
    const c10::optional<Tensor>& i3,
    IntArrayRef expand1,
    IntArrayRef expand2,
    IntArrayRef expand3,