Skip to content

Commit 80295cd

Browse files
Flamefire authored and pytorchmergebot committed
Adjust derivative calculation for removed linalg.solve optimization
1 parent a475f1a commit 80295cd

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,7 @@
15861586

15871587
- name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)
15881588
A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1])
1589-
result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())"
1589+
result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left)"
15901590
output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user
15911591

15921592
- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6033,8 +6033,7 @@ Tensor linalg_solve_jvp(
60336033
const Tensor& X,
60346034
const Tensor& LU,
60356035
const Tensor& pivots,
6036-
const bool left,
6037-
const bool use_A_T) {
6036+
const bool left) {
60386037
at::NoTF32Guard disable_tf32;
60396038
// For left=True (left=False is analogous)
60406039
// dX = A^{-1}(dB - dAX)
@@ -6056,8 +6055,7 @@ Tensor linalg_solve_jvp(
60566055
auto X_ = vector_to_matrix(X);
60576056
auto dB_ = vector_to_matrix(dB);
60586057
auto R_ = left ? dA.matmul(X_) : X_.matmul(dA);
6059-
auto dX_ =
6060-
at::linalg_lu_solve(LU, pivots, dB_ - R_, left, /*adjoint*/ use_A_T);
6058+
auto dX_ = at::linalg_lu_solve(LU, pivots, dB_ - R_, left);
60616059
return matrix_to_vector(dX_);
60626060
}
60636061

@@ -6095,9 +6093,8 @@ std::tuple<Tensor, Tensor> linalg_solve_backward(
60956093
if (at::GradMode::is_enabled()) {
60966094
gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left);
60976095
} else {
6098-
const auto use_A_T = A.is_contiguous() && !A.is_complex();
60996096
gB_ = at::linalg_lu_solve(
6100-
LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ !use_A_T);
6097+
LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ true);
61016098
}
61026099

61036100
Tensor gA_;

torch/csrc/autograd/FunctionsManual.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -896,8 +896,7 @@ Tensor linalg_solve_jvp(
896896
const Tensor& X,
897897
const Tensor& LU,
898898
const Tensor& pivots,
899-
const bool left,
900-
const bool use_A_T);
899+
const bool left);
901900
Tensor lu_unpack_backward(
902901
const Tensor& L_grad,
903902
const Tensor& U_grad,

0 commit comments

Comments (0)