
Conversation


@yanbing-j yanbing-j commented Apr 17, 2023

This PR adds support for sum.dim_IntList on sparse tensors, which was exposed in #98796.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10
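For context, a minimal hedged sketch of the call this PR enables (hypothetical values; reductions over sparse dimensions require keepdim=True, as discussed in the review below):

```python
import torch

# Hypothetical example of the feature added here: sum over one dimension of a CSR tensor.
a = torch.tensor([[0., 1., 2.],
                  [3., 0., 4.]])
csr = a.to_sparse_csr()
s = csr.sum(dim=0, keepdim=True)   # dispatches to the kernel added by this PR
print(s)                           # expected to match a.sum(dim=0, keepdim=True) = [[3., 1., 6.]]
```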

@pytorch-bot pytorch-bot bot added the release notes: sparse release notes category label Apr 17, 2023

pytorch-bot bot commented Apr 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99292

Note: Links to docs will display an error until the docs builds have been completed.

✅ 3 Unrelated Failures

As of commit e45db7a:

UNSTABLE - The following jobs failed, likely due to flakiness present on trunk, and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yanbing-j yanbing-j requested a review from mingfeima April 17, 2023 07:27
@yanbing-j yanbing-j force-pushed the yanbing/sparse_sum_csr branch from 5cc3191 to e64dd94 Compare April 18, 2023 02:36
@mingfeima mingfeima requested a review from pearu April 18, 2023 05:24
@@ -1046,7 +1046,7 @@ Tensor reduce_sparse_csr_dim0_cpu_template(const Tensor& sparse, ReductionOp rop

AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "reduce_sparse_csr_dim0_cpu_indices",
[&]() {
-index_t* columns_map_ptr = columns_map.data_ptr<index_t>();
+int64_t* columns_map_ptr = columns_map.data_ptr<int64_t>();
Collaborator

Good catch. Could you eliminate using AT_DISPATCH_INDEX_TYPES as index_t is not used in this block anymore?

Collaborator Author

Thanks for your comments. I find that simply changing index_t to int64_t here is not quite appropriate. The root cause is that the index type should align with columns_map's dtype, not col_indices's dtype. I have updated it.

@mingfeima mingfeima (Collaborator) left a comment

Test failures seem to be real, please have them fixed.

@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 25, 2023
@yanbing-j yanbing-j force-pushed the yanbing/sparse_sum_csr branch 2 times, most recently from 277d51b to 9995d22 Compare May 9, 2023 05:23
@yanbing-j yanbing-j requested review from mruberry and ngimel as code owners May 9, 2023 09:20
@yanbing-j (Collaborator Author)

Hi @pearu, I see your PR #100391 will raise an error in 'aten::sum.IntList_out' with arguments from 'SparseCsr(CPU|CUDA)', which conflicts with the current PR. Therefore, I removed that part of the code; only the error for bsr and bsc remains.

But I'm confused about how to pass your UT. Whether I raise an error or add a warning in the bsr and bsc part, self.assertFalse(isinstance(out, type(NotImplemented))) will fail. Could you please give me some advice, and also have a look at this PR? Thank you!

@pearu pearu (Collaborator) left a comment

Thanks, @yanbing-j, for this! I did my initial review of the PR and found that there exist more efficient paths to compute the sum of CSC tensors.

I'll address your testing questions a bit later.

@@ -2165,5 +2166,21 @@ Tensor sum_sparse_coo(const Tensor& self, at::OptionalIntArrayRef dim, bool keep
return result;
}

Tensor sum_sparse_csr(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype) {
Collaborator

SparseCsrCPU/CUDA dispatch keys represent all sparse compressed layouts: CSR, BSR, CSC, BSC. The same holds for the sum_sparse_csr function so I suggest:

Suggested change
-Tensor sum_sparse_csr(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype) {
+Tensor sum_sparse_compressed(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, c10::optional<ScalarType> dtype) {

// bit different in the second parameters `dim`, which causes the conversion of `dim`
// to call into `_sparse_csr_sum`. Align the signatures would be a better choice.
TORCH_CHECK(dim.has_value(),"dim has no value, cannot be used in sum.dim_IntList");
if (self.is_sparse_csr()) {
Collaborator

Use:

Suggested change
-if (self.is_sparse_csr()) {
+auto layout = self.layout();
+if (layout == kSparseCsr) {

Comment on lines 2177 to 2178
Tensor new_self = self.to_dense().to_sparse_csr();
return at::_sparse_csr_sum(new_self, *dim, true, dtype);
Collaborator

Conversion to a strided tensor is unnecessary; using self.to_sparse_csr() should be sufficient and is more efficient. That said, I think no conversions are necessary because we have the following invariants:

batch_dim = csc.dim() - csc.dense_dim() - csc.sparse_dim()
csc.layout == torch.sparse_csc
csc.transpose(batch_dim, batch_dim + 1).layout == torch.sparse_csr
torch.sum(csc, dim=dim, keepdim=True) == torch.sum(csc.transpose(batch_dim, batch_dim+1),
      dim=(*dim[:batch_dim], dim[batch_dim+1], dim[batch_dim], *dim[batch_dim+2:]),
      keepdim=True).transpose(batch_dim, batch_dim+1)

So, I suggest using the following method for computing reductions on a CSC tensor:

Suggested change
-Tensor new_self = self.to_dense().to_sparse_csr();
-return at::_sparse_csr_sum(new_self, *dim, true, dtype);
+auto batch_dim = csc.dim() - csc.dense_dim() - csc.sparse_dim();
+auto swapped_dim = ...; // a copy of `*dim` where `batch_dim` and `batch_dim+1`- th elements are swapped
+return at::_sparse_csr_sum(self.transpose(batch_dim, batch_dim+1), swapped_dim, true, dtype).transpose(batch_dim, batch_dim+1);
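To make the invariants above concrete, here is a hedged Python check for the plain 2-D case (batch_dim == 0), verified through dense tensors since the CSC sum path may not exist at this point:

```python
import torch

# Hedged sketch of the invariant above for a non-batched 2-D tensor (batch_dim == 0).
a = torch.tensor([[1., 0., 2.],
                  [0., 3., 0.]])
csc = a.to_sparse_csc()
# Transposing the two sparse dimensions of a CSC tensor yields a CSR view of the same data.
assert csc.transpose(0, 1).layout == torch.sparse_csr
# So summing the CSC tensor over dim d equals summing its CSR transpose over the other dim
# and transposing back (checked here via dense tensors).
for d in (0, 1):
    lhs = a.sum(dim=d, keepdim=True)
    rhs = csc.transpose(0, 1).to_dense().sum(dim=1 - d, keepdim=True).transpose(0, 1)
    assert torch.equal(lhs, rhs)
```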

Comment on lines 2180 to 2181
LOG(WARNING) << "Only SparseCsr and SparseCSC are supported for now";
return Tensor();
Collaborator

Suggested change
-LOG(WARNING) << "Only SparseCsr and SparseCSC are supported for now";
-return Tensor();
+TORCH_CHECK(false, "sum expected input with strided, sparse_csr, or sparse_csc layouts, got layout ", layout);
+return Tensor();

@@ -5389,6 +5389,7 @@
dispatch:
NestedTensorCPU: NestedTensor_sum_dim_CPU
SparseCPU, SparseCUDA: sum_sparse_coo
SparseCsrCPU, SparseCsrCUDA: sum_sparse_csr # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype
Collaborator

Nit:

Suggested change
-SparseCsrCPU, SparseCsrCUDA: sum_sparse_csr # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype
+SparseCsrCPU, SparseCsrCUDA: sum_sparse_compressed # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype

Comment on lines 1068 to 1074
new_values_acc_ptr[col] = rop(new_values_acc_ptr[col], static_cast<opmath_t>(val));
}
for (int64_t i = 0; i < nnz; i++) {
if (need_acc) {
new_values_ptr[i] = static_cast<scalar_t>(new_values_acc_ptr[i]);
} else {
new_values_ptr[i] = new_values_acc_ptr[i];
}
}
});
Collaborator

Optimization nit continued:

Suggested change
-new_values_acc_ptr[col] = rop(new_values_acc_ptr[col], static_cast<opmath_t>(val));
-}
-for (int64_t i = 0; i < nnz; i++) {
-if (need_acc) {
-new_values_ptr[i] = static_cast<scalar_t>(new_values_acc_ptr[i]);
-} else {
-new_values_ptr[i] = new_values_acc_ptr[i];
-}
-}
-});
+new_values_acc_ptr[col] = rop(new_values_acc_ptr[col], static_cast<opmath_t>(val));
+}
+});
+if (need_acc) {
+new_values.copy_(new_values_acc);
+}

[&]() {
index_t* columns_map_ptr = columns_map.data_ptr<index_t>();
scalar_t* values_ptr = values.data_ptr<scalar_t>();
opmath_t* new_values_acc_ptr = new_values_acc.data_ptr<opmath_t>();
scalar_t* new_values_ptr = new_values.data_ptr<scalar_t>();
Collaborator

Suggested change
-scalar_t* new_values_ptr = new_values.data_ptr<scalar_t>();

test/test_ops.py Outdated
@parametrize("layout", (torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc, torch.sparse_coo))
@parametrize("layout", (torch.sparse_bsr, torch.sparse_bsc, torch.sparse_coo))
Collaborator

This change is not needed as it will disable existing tests for CSR and CSC samples.

dense = sparse.to_dense()
for dim in (0, 1):
dense_sum = dense.sum(dim=dim)
sparse_sum = sparse.sum(dim=dim)
Collaborator

reductions over sparse dimensions of sparse compressed tensors require keepdim=True

def run_test(shape, nnz, index_type):
sparse = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
dense = sparse.to_dense()
for dim in (0, 1):
Collaborator

IIUC, dim is an integer that does not correspond to a new feature that this PR implements.

pearu commented May 9, 2023

@yanbing-j:

Hi @pearu, I see your PR #100391 will raise an error in 'aten::sum.IntList_out' with arguments from 'SparseCsr(CPU|CUDA)', which conflicts with the current PR. Therefore, I removed that part of the code; only the error for bsr and bsc remains.

Note that in the end, you should restore test_ops.py as it was.

But I'm confused about how to pass your UT. Whether I raise an error or add a warning in the bsr and bsc part, self.assertFalse(isinstance(out, type(NotImplemented))) will fail. Could you please give me some advice, and also have a look at this PR? Thank you!

The assertFalse fails because out is None due to
https://github.com/yanbing-j/pytorch/blob/5dd3bb8fa2500b030e0d3b9f58047614a1555d06/aten/src/ATen/native/ReduceOps.cpp#L2180-L2181
You should raise exception as discussed in #99292 (comment) and then update _validate_sample_input_sparse_reduction_sum in torch/testing/_internal/opinfo/definitions/sparse.py accordingly.

Comment on lines 2179 to 2188
if (self.dim() != 2 || keepdim) {
TORCH_CHECK(
false,
"sum expected input with strided, sparse_csr layouts, got layout ",
layout);
} else if (!keepdim) {
TORCH_CHECK(
false,
"torch.empty: Only batched sparse compressed (non-block) tensors are supported");
}
Collaborator

If we don't implement support for non-CSR layouts, there is no need to check other parameters. So:

Suggested change
-if (self.dim() != 2 || keepdim) {
-TORCH_CHECK(
-false,
-"sum expected input with strided, sparse_csr layouts, got layout ",
-layout);
-} else if (!keepdim) {
-TORCH_CHECK(
-false,
-"torch.empty: Only batched sparse compressed (non-block) tensors are supported");
-}
+TORCH_CHECK(
+false,
+"sum expected input with strided or sparse_csr layout, got layout ",
+layout);
+}

@yanbing-j yanbing-j (Collaborator Author) May 10, 2023

If we don't check the parameters, every RuntimeError would be `sum expected input with strided or sparse_csr layout, got layout ...`, which cannot pass the cases that expect the error `torch.empty: Only batched sparse compressed ...`. WDYT?

Collaborator

The expected error `torch.empty: Only batched sparse compressed ...` is just wrong and should be updated in sparse.py after the error message is updated.

Comment on lines 1178 to 1161
if (need_acc) {
new_values_ptr[row_map_ptr[h]] = static_cast<scalar_t>(res);
} else {
new_values_ptr[row_map_ptr[h]] = res;
}
Collaborator

Nit:

Suggested change
-if (need_acc) {
-new_values_ptr[row_map_ptr[h]] = static_cast<scalar_t>(res);
-} else {
-new_values_ptr[row_map_ptr[h]] = res;
-}
+new_values_ptr[row_map_ptr[h]] = static_cast<scalar_t>(res);

Comment on lines 1140 to 1125
bool need_acc = (values.scalar_type() == kHalf || values.scalar_type() == kBFloat16 || values.scalar_type() == kComplexHalf);

Collaborator

Nit:

Suggested change
-bool need_acc = (values.scalar_type() == kHalf || values.scalar_type() == kBFloat16 || values.scalar_type() == kComplexHalf);

@@ -1042,23 +1042,37 @@ Tensor reduce_sparse_csr_dim0_cpu_template(const Tensor& sparse, ReductionOp rop
new_crow_indices[1] = nnz;

Tensor new_values = at::empty({nnz}, values.options());
new_values.fill_(rop.identity());
bool need_acc = (values.scalar_type() == kHalf || values.scalar_type() == kBFloat16 || values.scalar_type() == kComplexHalf);
Collaborator

Just a note: in principle, the need_acc information is defined by scalar_t and ReductionOp. This would allow compile-time optimizations because the if-blocks on need_acc can use constexpr. For instance, if ReductionOp would store its template typename as a type member, we could have:

constexpr bool need_acc = !std::is_same<scalar_t, rop::type>::value;
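The motivation for the accumulation buffer is easy to see from plain Python, independent of the C++ internals (a hedged illustration):

```python
import torch

# Why reduced-precision values need a wider accumulation type: naive bfloat16
# accumulation stalls once the running sum can no longer represent "sum + 1".
x = torch.ones(1000, dtype=torch.bfloat16)
acc = torch.tensor(0., dtype=torch.bfloat16)
for v in x:
    acc = acc + v
print(acc)      # stalls around 256 (bfloat16 has an 8-bit mantissa)
print(x.sum())  # ~1000, because the built-in reduction accumulates in float32 and casts back
```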

Comment on lines 2808 to 2812
dense_sum = dense.sum(dim=dim)
sparse_sum = sparse.sum(dim=dim, keepdim=True)
is_integral = dtype in integral_types()
self.assertEqual(sparse_sum.to_dense().view(dense_sum.shape)
if not is_integral else sparse_sum.to_dense().to(torch.int64).view(dense_sum.shape), dense_sum)
Collaborator

We should have the following invariant holding:

sparse.sum(dim=dim, keepdim=True).to_dense() == sparse.to_dense().sum(dim=dim, keepdim=True)

So:

Suggested change
-dense_sum = dense.sum(dim=dim)
-sparse_sum = sparse.sum(dim=dim, keepdim=True)
-is_integral = dtype in integral_types()
-self.assertEqual(sparse_sum.to_dense().view(dense_sum.shape)
-if not is_integral else sparse_sum.to_dense().to(torch.int64).view(dense_sum.shape), dense_sum)
+dense_sum = dense.sum(dim=dim, keepdim=True)
+sparse_sum = sparse.sum(dim=dim, keepdim=True)
+self.assertEqual(sparse_sum, dense_sum)

(assertEqual will handle the conversion of sparse_sum to a strided tensor.)
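Outside the test harness, the invariant can be spot-checked directly; a hedged sketch assuming the CSR sum support from this PR is available:

```python
import torch

# Hedged spot-check of the invariant stated above.
a = torch.tensor([[0., 1., 2.],
                  [3., 0., 4.]])
csr = a.to_sparse_csr()
for d in (0, 1):
    assert torch.equal(csr.sum(dim=d, keepdim=True).to_dense(),
                       a.sum(dim=d, keepdim=True))
```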

Comment on lines 2813 to 2819
if dtype in floating_types():
sparse_sum.requires_grad_(True)
sparse_sum.sum().backward()
dense_sum.requires_grad_(True)
dense_sum.sum().backward()
self.assertEqual(sparse_sum.grad.view(dense_sum.shape), torch.ones(dense_sum.shape, dtype=dtype, device=device))
self.assertEqual(sparse_sum.grad.view(dense_sum.shape), dense_sum.grad)
Collaborator

These tests are already covered by test_sum. So:

Suggested change
-if dtype in floating_types():
-sparse_sum.requires_grad_(True)
-sparse_sum.sum().backward()
-dense_sum.requires_grad_(True)
-dense_sum.sum().backward()
-self.assertEqual(sparse_sum.grad.view(dense_sum.shape), torch.ones(dense_sum.shape, dtype=dtype, device=device))
-self.assertEqual(sparse_sum.grad.view(dense_sum.shape), dense_sum.grad)

Collaborator

In fact, considering the tests test_reductions and test_reductions_backward in test_sparse.py, there is no need for test_sum_dim_reduce.

Just verify that sample_inputs_sparse_reduction_sum in torch/testing/_internal/opinfo/definitions/sparse.py produces the equivalent samples (I believe it does) and remove test_sum_dim_reduce.

@pearu pearu (Collaborator) left a comment

I have a number of suggestions to simplify the code. Also, test_sum_dim_reduce appears to be unnecessary as the opinfo based test functions cover the samples corresponding to the feature added in this PR.

yanbing-j commented May 10, 2023

I have one more question: for a test like test_consistency_SparseCSR_sum_cpu_bfloat16, the input dimension would be larger than 2. Is this expected? However, _sparse_csr_sum only supports dim == 2. What can we do to pass the UTs?

Add the corresponding if-block to _validate_sample_input_sparse_reduction_sum in sparse.py.
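A hedged sketch of what such an if-block could look like (the real signature and error text in sparse.py may differ; this only illustrates mapping unsupported >2-D samples to ErrorInput):

```python
import torch
from torch.testing._internal.opinfo.core import ErrorInput

# Hypothetical branch: treat >2-D CSR samples as expected failures instead of
# feeding them to _sparse_csr_sum, which only handles 2-D inputs.
def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False):
    t = sample.input
    if t.layout is torch.sparse_csr and t.dim() != 2:
        return ErrorInput(sample, error_regex="placeholder: expected 2-D input")
    return sample
```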

@yanbing-j yanbing-j force-pushed the yanbing/sparse_sum_csr branch 5 times, most recently from f4250fe to 49badf8 Compare May 14, 2023 08:12
@yanbing-j (Collaborator Author)

@pearu, thank you so much for the comments. Could you please take another look at this PR?

@yanbing-j yanbing-j requested review from pearu and mingfeima May 14, 2023 11:09
@pearu pearu (Collaborator) left a comment

This looks good, thanks @yanbing-j! But there are also changes that need to be revised/fixed, in particular:

  • sum on a CSR tensor appears to ignore the user-specified dtype argument. In general, sum on a tensor with any sparse layout must have exactly the same behavior as sum on a strided tensor.
  • Consistency checks are disabled for CSR and CSC samples. If the consistency checks fail on such samples, we should fix the issue rather than ignore it, or, worst of all, disable the consistency checks for all other operations as well, as this PR does.
  • The suggestion to rename sum_sparse_csr to sum_sparse_compressed is not applied. Although this PR addresses only CSR tensors, we will use the same function for other sparse compressed layouts in follow-ups. There are no intentions to introduce sum_sparse_csc etc. functions, so let's use the correct naming of the function immediately.

Comment on lines 1047 to 1054
auto values_acc_option = values.options();
if (need_acc) {
values_acc_option = values.scalar_type() == kComplexHalf
? values.options().dtype(ScalarType::ComplexFloat)
: values.options().dtype(ScalarType::Float);
}
Tensor new_values_acc =
(need_acc ? at::empty({nnz}, values_acc_option) : new_values);
@pearu pearu (Collaborator) May 15, 2023

Now that need_acc is constexpr, I suggest using:

Suggested change
-auto values_acc_option = values.options();
-if (need_acc) {
-values_acc_option = values.scalar_type() == kComplexHalf
-? values.options().dtype(ScalarType::ComplexFloat)
-: values.options().dtype(ScalarType::Float);
-}
-Tensor new_values_acc =
-(need_acc ? at::empty({nnz}, values_acc_option) : new_values);
+Tensor new_values_acc;
+if constexpr (need_acc) {
+auto acc_dtype = values.scalar_type() == kComplexHalf ? ScalarType::ComplexFloat : ScalarType::Float;
+new_values_acc = at::empty({nnz}, values.options().dtype(acc_dtype));
+} else {
+new_values_acc = new_values;
+}

Comment on lines 1303 to 1306
auto is_integral = at::isIntegralType(dtype_, /*includeBool=*/false);
if (is_integral) {
result = result.to(ScalarType::Long);
}
Collaborator

Hmm, this is wrong as it overrides the user-specified dtype value. For strided tensors, we have sum result dtype tests in test/test_reductions.py but it looks like we don't generate sparse samples for sparse reduction tests (test_reductions) in test/test_sparse.py that would have caught this issue.
I think we should have:

Suggested change
-auto is_integral = at::isIntegralType(dtype_, /*includeBool=*/false);
-if (is_integral) {
-result = result.to(ScalarType::Long);
-}
+auto is_integral = !dtype.has_value() && at::isIntegralType(dtype_, /*includeBool=*/false);
+if (is_integral) {
+result = result.to(ScalarType::Long);
+}

or similar.

Comment on lines 639 to 643
auto is_integral = at::isIntegralType(dtype_, /*includeBool=*/false);
if (is_integral) {
result = result.to(ScalarType::Long);
}
Collaborator

Same note here as above: when a user specifies dtype then the result must have the specified dtype.
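For reference, the dense behavior that the sparse path should reproduce (a hedged illustration using plain strided tensors):

```python
import torch

# Dense reference behavior for sum dtype handling:
x = torch.ones(3, dtype=torch.int32)
print(x.sum().dtype)                   # torch.int64: integral inputs promote to long by default
print(x.sum(dtype=torch.int32).dtype)  # torch.int32: an explicit dtype must be honored
# The sparse kernels should only promote to int64 when the user did not pass dtype.
```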

Comment on lines 498 to 501
if layout == torch.sparse_csc:
return x
if layout == torch.sparse_csr and (x.dtype == torch.bool or x.dtype == torch.complex32) and op.name == "sum":
return x
Collaborator

Why this change? This disables the consistency checks for CSC and CSR tensors and we should not do that.

Collaborator Author

These changes are because _validate_sample_input_sparse_reduction_sum limits CSC tensors and CSR tensors with bool or complex32. Therefore the sample will not be added into samples in test_sparse_csr.py, and the test raises the error `Expected at least one 2 or higher D tensor in samples`.

Collaborator Author

WDYT? @pearu

Collaborator

samples that don't pass _validate_sample_input_sparse_reduction_sum (read: such samples are mapped to ErrorInput instances) are not included in the set of samples that op.sample_inputs(device, dtype) generates. So, this change is unnecessary (it ought to be).

In addition, notice that test_consistency is used for many operations, not just sum. So, any changes to this test must not affect testing other operations. This PR just does this: it disables consistency tests for all non-sum operations with respect to CSC inputs.

So, please undo this change here and report what problems exist for the sum operation. If there are any, these problems ought to be tackled in sparse.py, not here.

@yanbing-j yanbing-j (Collaborator Author) May 17, 2023

Thanks for the clarification, @pearu . I have removed the unnecessary change in test_sparse_csr.py.

Now the problem is that, in sparse.py, CSC inputs and CSR inputs with bool and complex32 dtypes will generate ErrorInput, which is expected in this PR. So samples becomes empty and the test raises the error `Expected at least one 2 or higher D tensor in samples`. What can we do in sparse.py to fix this error? Thank you!

            if validate_sample_input_sparse(op, sparse_sample, check_validate=False) is not sparse_sample:
                # that is, the validation returns the sparse sample
                # wrapped within ErrorInput instance
                continue
            samples.append((sample, sparse_sample))

        # Fail early to prevent silent success with this test
        if len(samples) == 0:
            raise ValueError("Expected at least one 2 or higher D tensor in samples.")

dim.has_value(), "dim has no value, cannot be used in sum.dim_IntList");
auto layout = self.layout();
TORCH_CHECK(layout == kSparseCsr,
"sum expected input with strided, sparse_csr layouts, got layout ", layout)
Contributor

Can you update this to "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout ", layout. It's also not true that only strided and CSR are supported. We also have a kernel registered under the Sparse dispatch key which means support for COO.
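For reference, a hedged illustration of the reviewer's point about COO (it relies only on the sum_sparse_coo registration quoted earlier in this thread):

```python
import torch

# COO reductions with dim= already dispatch to sum_sparse_coo (SparseCPU/SparseCUDA),
# independently of this PR; shown only to illustrate that COO is supported.
coo = torch.tensor([[0., 1.], [2., 0.]]).to_sparse()
print(coo.sum(dim=0))
```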

@cpuhrsch cpuhrsch (Contributor) left a comment

There's a lot of duplicated code around

  if constexpr (need_acc) {
    auto acc_dtype = CppTypeToScalarType<acc_t>::value;
    new_values_acc = at::empty({}, values.options().dtype(acc_dtype));
    new_values = is_integral ? new_values_acc : at::empty({}, values.options());
  } else {
    new_values_acc = new_values = at::empty({}, values.options());;
  }

Can we abstract that into a helper within a header or such and unify that logic into a single place? We don't want to have this diverge between devices or formats or input settings.

@yanbing-j yanbing-j force-pushed the yanbing/sparse_sum_csr branch 2 times, most recently from b3524bf to 06f1756 Compare July 19, 2023 09:26
@yanbing-j yanbing-j requested a review from cpuhrsch July 19, 2023 09:27
@@ -366,5 +367,36 @@ inline bool only_sparse_compressed_add_trivial_cases(
});
}

inline Tensor to_type(Tensor input, ScalarType dtype) {
Contributor

Did you check whether we already have support for this conversion? We have the torch.Tensor.to operator.

Collaborator Author

Please refer to the comment #99292 (comment).

// of float, while in CPU, double is the accumulate type of float.
using acc_t = at::acc_type<scalar_t, true>;
constexpr bool need_acc = !std::is_same<scalar_t, acc_t>::value;
bool is_integral = at::isIntegralType(values.scalar_type(), /*includeBool=*/true);
Contributor

This line is repeated 4 times. The intent of my higher level comment was not around a single block of particular code. It was to abstract away repeated code.

@cpuhrsch cpuhrsch (Contributor) left a comment

Thank you for addressing my comments. I think we can simplify this a bit more by abstracting away more code and checking for existing functionality.

@yanbing-j yanbing-j force-pushed the yanbing/sparse_sum_csr branch from 06f1756 to 879bd2c Compare July 20, 2023 13:00
@yanbing-j yanbing-j requested a review from cpuhrsch July 20, 2023 13:53
}

template <typename acc_t, typename scalar_t>
inline void create_acc_buffer(
Contributor

I think we can also include the construction and resize of Tensor new_values, new_values_acc; into this helper function.

@yanbing-j yanbing-j force-pushed the yanbing/sparse_sum_csr branch 2 times, most recently from 7eaae75 to 479b4ec Compare July 21, 2023 08:48
@@ -1124,9 +1142,12 @@ Tensor reduce_sparse_csr_dim1_cpu_template(const Tensor& sparse, ReductionOp rop
new_col_indices.resize_(nnz);
new_col_indices.fill_(index_t(0));
new_values.resize_(nnz);
if (!new_values_acc.is_same(new_values)) {
Contributor

Thank you for addressing the comments. Is there something you can do to abstract away this line as well? It's repeated 5 times.

Simplify the dispatch

Add UT

Fix the bug of index_type in reduce_sparse_csr_dim0_cpu_template

Use opmath_t as reduction type

Fix CI failures

The failures are from:
1. The SparseCsrCPU dispatch key covers csc, bsr, and bsc in addition to csr. Only csr is supported for now; csc can easily be converted to csr, while bsr and bsc cannot.
2. This PR conflicts with pytorch#100391.

Remove the change in test_sparse_csr.py

Remove AT_DISPATCH_INDEX_TYPES based on comments

Remove to(torch.int64) explicitly and change input data type to torch.int64

Support integral return value in reduce_sparse_csr_cpu_template

Refactor according to comments

Abstract to_type in sparse_csr

Update based on comments
@yanbing-j yanbing-j force-pushed the yanbing/sparse_sum_csr branch from 479b4ec to e45db7a Compare July 23, 2023 04:14
@yanbing-j yanbing-j requested a review from cpuhrsch July 23, 2023 05:16
@cpuhrsch (Contributor)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pearu added a commit that referenced this pull request Sep 4, 2023
This PR addresses the sparse tensors part of #99655

The PR introduces the following utility functions:
- `at::sparse_csr::alias_with_values(a_sparse_compressed_tensor, new_values)`
- `at::sparse::alias_with_values(a_sparse_tensor, new_values)`

These functions return a wrapper of a sparse tensor with new specified values that allow introducing alias support for sparse tensors and more (e.g. the most efficient way to resolve #99292 (comment) is to use `at::sparse_csr::alias_with_values(self, self.values().to(dtype))` as a replacement of `self.to(dtype)` to avoid the unnecessary copy of indices).




[ghstack-poisoned]
Labels
ciflow/trunk (Trigger trunk jobs on your pull request) · intel (This tag is for PR from Intel) · Merged · open source · release notes: sparse (release notes category) · triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
Projects
Status: Done

9 participants