Jax support for sparse matrices #6544

@GJBoth

Description

Hi all,

I was discussing adding support for sparse matrices in Jax with @jakevdp in #6466. He has already done quite a bit of groundwork, but there are still a lot of open questions about the what and the how. Quoting Jake here:

  • 2D CSC/CSR/COO may not be the abstraction we want in the end... pydata-sparse has N-dimensional sparse arrays implemented using COO or a multi-dimensional generalization of CSR. Similarly taco has routines for general N-dimensional sparse computation. That route may be better, particularly for some deep learning workflows that e.g. have 4D batched arrays of sparse image sequences.
  • XLA scatter/gather implementations only go so far... any operations we implement should have efficient lowerings on as many backends as possible - I've started with cusparse on GPU because the low-level routines are already available in our GPU build process. This means that every non-trivial operation should be either implemented as a primitive or composed of other primitives
  • JAX is about composability, and in particular pmap/vmap, jit, and grad are central to the API. I want to spend some time exploring how we can best implement gradients of sparse operations or batched sparse operations, and also think about how ongoing work on dynamic shapes can be utilized to make jit of sparse operations more efficient and usable. With all that up in the air I want to keep the design minimal for now in case we need to re-think our approach based on what we find
  • all of the above should be motivated by real-world workflows that people have. I'm starting to gather a few example workflows so we can have that in mind.

So I set out to compare the implementations in Pytorch and Tensorflow to get a feel for what's out there, with the points above in mind. Specifically, we're interested in the low-level implementation: do they secretly turn everything dense under the surface, or do sparse tensors get the full treatment with dedicated ops?

I'm summarizing what I found in the table below - TL;DR: Pytorch offers better support. Since documentation and support are all over the place, this notebook runs a bunch of common operations on sparse tensors to check which ones are supported. It's been a while since I last used Tensorflow, so let me know if I made a mistake!

| Feature | Pytorch | Tensorflow |
| --- | --- | --- |
| Documentation | here | here |
| Formats | COO | COO |
| Supported ops | Variations of matmul | Sums, min, matmul, element-wise ops through tf.sparse.map_values |
| Specialized ops | Probably¹ | Seems so² |
| Grad of sparse³ | Limited to specific matmul | No |
| Sparse grad⁴ | Yes | No |
| Dimensions | 2D, hybrid | 2D (the tensor supports higher dimensions, but the operations don't) |
| Uncoalesced⁵ allowed | Yes | No mention |
| Sparse-sparse matmul | No | No |
| Extra goodies | Has an Adam optimizer for sparse tensors | Nope |

To do

  • Dig into the source code a little further (although if someone who is actually familiar with the code could chime in, that'd be great).
  • Check how pydata-sparse implements n-D sparse arrays (a rough sketch of the general idea follows this list).
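
To have something concrete in mind for the n-D question, here is a minimal jax.numpy sketch of a pydata-sparse-style N-dimensional COO container: an (nnz, ndim) coordinate array plus an (nnz,) value array. The class name NDCOO and its layout are purely illustrative assumptions on my part, not a proposed API:

```python
import jax.numpy as jnp

class NDCOO:
    """Illustrative N-dimensional COO container (not a proposed API)."""

    def __init__(self, coords, data, shape):
        self.coords = jnp.asarray(coords)  # (nnz, ndim) integer coordinates
        self.data = jnp.asarray(data)      # (nnz,) nonzero values
        self.shape = shape

    def todense(self):
        out = jnp.zeros(self.shape, self.data.dtype)
        # Scatter-add, so duplicate (uncoalesced) coordinates sum up.
        return out.at[tuple(self.coords.T)].add(self.data)

# Made-up 4D example, e.g. a small batch of sparse image frames.
m = NDCOO(coords=[[0, 0, 2, 3], [1, 0, 5, 7]],
          data=[1.0, 2.0],
          shape=(2, 1, 8, 8))
dense = m.todense()  # shape (2, 1, 8, 8) with two nonzero entries
```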

Proposal:

  • Get a very minimal differentiable sparse matmul working for the COO format, using XLA ops and one GPU-specific backend in cusparse (some help and guidance here would be nice); a rough sketch of the XLA-ops part is below.
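
To make that proposal concrete, here is a rough sketch of a differentiable COO matrix-vector product built only from XLA-friendly ops (a gather plus a segment sum), so jax.grad works without densifying. There is no cusparse lowering here, and the name coo_matvec and the (data, row, col) layout are just assumptions for illustration:

```python
import jax
import jax.numpy as jnp

def coo_matvec(data, row, col, v, *, n_rows):
    """Compute y = A @ v for a COO matrix A given by (data, row, col).

    data: (nnz,) nonzero values
    row, col: (nnz,) integer coordinates of each nonzero
    v: (n_cols,) dense vector
    """
    # Gather the entries of v that each nonzero multiplies, then
    # scatter-add the products into the output rows.
    prod = data * v[col]
    return jax.ops.segment_sum(prod, row, num_segments=n_rows)

# Illustration with a made-up 2x3 matrix [[1, 0, 2], [0, 3, 0]].
data = jnp.array([1.0, 2.0, 3.0])
row = jnp.array([0, 0, 1])
col = jnp.array([0, 2, 1])
v = jnp.array([1.0, 1.0, 1.0])

y = coo_matvec(data, row, col, v, n_rows=2)  # [3., 3.]

# Differentiable w.r.t. the nonzero values, without ever going dense:
grad_data = jax.grad(lambda d: coo_matvec(d, row, col, v, n_rows=2).sum())(data)
# grad_data == v[col] == [1., 1., 1.]
```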

I figured we can use this issue to track progress and discuss, so ... thoughts?

Footnotes

  1. They mention here that Torchcudasparse is now defunct, but I can't find any other mention of it in the source code.

  2. Sparse ops are imported from gen_sparse_ops.py, which doesn't exist in the repo?

  3. Meaning the gradient w.r.t. the sparse tensor.

  4. Meaning calculating the gradient without turning the tensor dense.

  5. Uncoalesced means there are duplicate entries in the coordinates; the total value of the element is then the sum of the duplicates.
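
To make footnote 5 concrete, here is a minimal sketch of coalescing duplicate COO entries with jax.numpy; the coordinates and values are made up for illustration:

```python
import jax
import jax.numpy as jnp

# Uncoalesced 2x2 COO matrix: the coordinate (0, 1) appears twice.
indices = jnp.array([[0, 1], [1, 0], [0, 1]])  # (nnz, ndim)
data = jnp.array([1.0, 2.0, 3.0])              # (nnz,)

# Coalesce: find the unique coordinates and sum the values of duplicates.
unique_idx, inverse = jnp.unique(indices, axis=0, return_inverse=True)
coalesced = jax.ops.segment_sum(data, inverse.ravel(),
                                num_segments=unique_idx.shape[0])
# unique_idx -> [[0, 1], [1, 0]], coalesced -> [4., 2.]
```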
