Hi all,
I was discussing with @jakevdp in #6466 about adding support for sparse matrices in JAX. He has already done quite a bit of groundwork, but there are a lot of open questions about the what and the how. Quoting Jake here:
- 2D CSC/CSR/COO may not be the abstraction we want in the end... pydata-sparse has N-dimensional sparse arrays implemented using COO or a multi-dimensional generalization of CSR. Similarly taco has routines for general N-dimensional sparse computation. That route may be better, particularly for some deep learning workflows that e.g. have 4D batched arrays of sparse image sequences.
- XLA scatter/gather implementations only go so far... any operations we implement should have efficient lowerings on as many backends as possible - I've started with cusparse on GPU because the low-level routines are already available in our GPU build process. This means that every non-trivial operation should be either implemented as a primitive or composed of other primitives.
- JAX is about composability, and in particular pmap/vmap, jit, and grad are central to the API. I want to spend some time exploring how we can best implement gradients of sparse operations or batched sparse operations, and also think about how ongoing work on dynamic shapes can be utilized to make jit of sparse operations more efficient and usable. With all that up in the air I want to keep the design minimal for now in case we need to re-think our approach based on what we find.
- all of the above should be motivated by real-world workflows that people have. I'm starting to gather a few example workflows so we can have that in mind.
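To make the second and third points a bit more concrete, here is a minimal sketch (not Jake's code, and not an existing JAX API) of an N-dimensional COO layout whose operations are composed only of existing primitives (scatter-add, gather, `jax.ops.segment_sum`), so `jit` and `vmap` apply without any custom rules. The layout (`data` of shape `(nnz,)`, `indices` of shape `(nnz, ndim)`) is an assumption loosely modeled on the pydata-sparse style, and `coo_todense` / `coo_matvec` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Hypothetical N-D COO layout (an assumption, not a JAX API):
#   data:    (nnz,)       nonzero values
#   indices: (nnz, ndim)  integer coordinates, one row per nonzero
#   shape:   tuple of ints, the dense shape

def coo_todense(data, indices, shape):
    # Scatter-add values into a zero array; duplicate coordinates
    # (uncoalesced entries) end up summed.
    out = jnp.zeros(shape, dtype=data.dtype)
    return out.at[tuple(indices.T)].add(data)

def coo_matvec(data, indices, shape, x):
    # 2D case: y[i] = sum_j A[i, j] * x[j], touching only the nonzeros.
    rows, cols = indices[:, 0], indices[:, 1]
    contrib = data * x[cols]  # gather from the dense vector
    return jax.ops.segment_sum(contrib, rows, num_segments=shape[0])

data = jnp.array([1.0, 2.0, 3.0])
indices = jnp.array([[0, 0], [1, 2], [2, 1]])
shape = (3, 3)
dense = coo_todense(data, indices, shape)
x = jnp.array([1.0, 10.0, 100.0])

# jit: `shape` must be static because it determines output sizes.
matvec = jax.jit(coo_matvec, static_argnums=2)
y = matvec(data, indices, shape, x)

# vmap over a batch of dense vectors, keeping the sparse operand fixed.
xs = jnp.stack([x, 2 * x])
ys = jax.vmap(lambda v: coo_matvec(data, indices, shape, v))(xs)

# The same layout covers higher-dimensional arrays, e.g. a 3D one.
t = coo_todense(jnp.array([5.0, 1.0]), jnp.array([[1, 0, 2], [0, 1, 1]]), (2, 2, 3))
```

Note that `shape` has to be static under `jit`, and `nnz` is baked into the traced array shapes; that is one place where the dynamic-shapes work mentioned above would matter.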
So I set out to compare the implementations in PyTorch and TensorFlow, with the points above in mind, to get a feel for what's out there. Specifically, we're interested in how they implement sparse support at a low level: do they secretly convert everything to dense under the hood, or do sparse tensors get the full treatment with dedicated ops?
I've summarized what I found in the table below. TL;DR: PyTorch offers better support. Since documentation and support are all over the place, in this notebook you can find a bunch of common operations on sparse tensors, tested to see whether they are supported. It's been a while since I used TensorFlow, so let me know if I made a mistake!
Feature | PyTorch | TensorFlow |
---|---|---|
Documentation | here | here |
Formats | COO | COO |
Supported ops | Variations of matmul | Sums, min, matmul, element-wise ops through `tf.sparse.map_values` |
Specialized ops | Probably¹ | Seems so² |
Grad of sparse³ | Limited to specific matmul | No |
Sparse grad⁴ | Yes | No |
Dimensions | 2D, hybrid | 2D (tensors support higher dimensions, but operations do not) |
Uncoalesced⁵ allowed | Yes | No mention |
Sparse-sparse matmul | No | No |
Extra goodies | Has an Adam optimizer for sparse tensors | None |
To do
- Dig into the source code a little further (if someone who is actually familiar with the code could chime in, that'd be great).
- Check how pydata-sparse implements n-D sparse arrays.
Proposal:
- Get a very minimal differentiable sparse matmul working for the COO format using XLA ops, plus one GPU-specific backend using cusparse (some help and guidance here would be nice).
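As a rough sketch of the XLA-ops half of this proposal (the cusparse backend is not covered), a COO sparse-times-dense matmul can be written with a gather plus `jax.ops.segment_sum`, and `jax.grad` then works out of the box. This assumes COO stored as separate `rows`/`cols` index arrays, and `coo_matmul` is a hypothetical name. Because the sparse operand is parameterized by its nonzero values, the gradient with respect to `data` comes back with shape `(nnz,)`, i.e. without materializing a dense gradient, in the spirit of the "Sparse grad" row in the table.

```python
import jax
import jax.numpy as jnp

def coo_matmul(data, rows, cols, n_rows, B):
    # C = A @ B, with A in COO form (data, rows, cols) and B dense (n_cols, k).
    # Each nonzero A[r, c] contributes data * B[c, :] to row r of C.
    contrib = data[:, None] * B[cols]  # (nnz, k): gather rows of B
    return jax.ops.segment_sum(contrib, rows, num_segments=n_rows)

data = jnp.array([1.0, 2.0, 3.0])
rows = jnp.array([0, 1, 2])
cols = jnp.array([0, 2, 1])
B = jnp.arange(6.0).reshape(3, 2)

def loss(data, B):
    return jnp.sum(coo_matmul(data, rows, cols, 3, B) ** 2)

# Gradients w.r.t. the nonzero values (shape (nnz,)) and the dense operand.
g_data, g_B = jax.grad(loss, argnums=(0, 1))(data, B)
```

Here `g_data[k]` equals the dense gradient of the loss with respect to `A` evaluated at `(rows[k], cols[k])`, which is what a sparse-aware backward pass should produce.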
I figured we can use this issue to track progress and discuss, so ... thoughts?
Footnotes
1. They mention Torchcudasparse is now defunct here, but I can't find any other mention of it in the source code.
2. Sparse ops are imported from gen_sparse_ops.py, which doesn't exist in the repo?
3. Meaning the gradient w.r.t. the sparse tensor.
4. Meaning calculating the gradient without turning it into a dense tensor.
5. Uncoalesced means duplicate entries in the coordinates; the total value of an element is then the sum of its duplicate entries.
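To illustrate footnote 5, coalescing a 2D COO array means summing the values that share a coordinate. Below is a minimal sketch in JAX, assuming separate `rows`/`cols` index arrays; `coalesce` is a hypothetical helper, not an existing API. Because `jnp.unique` returns a data-dependent number of entries, this runs eagerly (outside `jit`); under `jit` one would have to pass a static `size` to `jnp.unique`, which is another spot where dynamic shapes come into play.

```python
import jax
import jax.numpy as jnp

def coalesce(data, rows, cols, n_cols):
    # Encode each (row, col) pair as a single integer key so duplicates collide.
    keys = rows * n_cols + cols
    # Sorted unique keys, plus for every entry the index of its unique key.
    uniq, inv = jnp.unique(keys, return_inverse=True)
    # Sum the values of entries that share a coordinate.
    summed = jax.ops.segment_sum(data, inv.ravel(), num_segments=uniq.shape[0])
    # Decode the keys back into (row, col) coordinates.
    return summed, uniq // n_cols, uniq % n_cols

# Entry (0, 1) appears twice: 2.0 + 5.0 should collapse to 7.0.
data = jnp.array([2.0, 1.0, 5.0])
rows = jnp.array([0, 1, 0])
cols = jnp.array([1, 0, 1])
vals, r, c = coalesce(data, rows, cols, n_cols=2)
```

The result is also sorted in row-major order, which is the property CSR-style formats rely on.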