Integrating complex tensors #755

@PhilippPelz

Description

New description from @ezyang:

Work is in progress at https://github.com/Roger-luo/pytorch-complex

Organizational principles

  • Complex tensor support is important to PyTorch, and we will accept patches to core which add small amounts of code to make adding complex support possible.
  • Adding complex support involves writing a lot of new kernels and code: we'd like this code to initially live out of the main repo, so it is easier for people to iterate quickly on it without having to go through the PyTorch main code review process. We will NOT commit to reviewing large new kernels in the short term, but eventually we would like all the kernels to come back to PyTorch.
  • The external library will be buildable separately from PyTorch, so you will be able to maintain it as a separate repository without having to merge with PyTorch (and deal with loads of merge conflicts).
    • PyTorch may occasionally make breaking changes in C++ API; if you bring these to our attention we will do our utmost to help solve these problems.
  • The hooks needed for this will NOT ship with PyTorch 1.0, but they will ship with a released version of PyTorch in the not too distant future.

How will I work on complex kernels?

Here is what the workflow will look like in the steady state.

PyTorch will natively contain APIs for referring to the complex dtype, but they won't do anything by default. PyTorch defines torch.complex64 and torch.complex128 dtypes for complex tensors. However, if you try to construct a tensor this way, by default PyTorch will error:

>>> torch.zeros(2, 2, dtype=torch.complex64)
RuntimeError: complex64 not supported by PyTorch

@ezyang provided a patch which adds these dtypes to PyTorch (#11173).

In the mid-term, we will merge support for basic functionality (like allocating a tensor of zeros) into PyTorch natively. A reasonable proxy for what support counts as “basic” is PyTorch's native support for CPU half tensors (which is extremely impoverished).

PyTorch publishes an interface for registering an implementation of complex tensors. The implementation inherits from the TypeDefault class (#11013) and overrides methods on this class for the functions that have complex implementations. It will look something like this:

struct CPUComplexFloatType final : public TypeDefault {
  virtual Tensor add(const Tensor & self, const Tensor & other, Scalar alpha=1) const override {
    // Your implementation of add for complex tensors
  }
  // ...
};

This class will override exactly the operations which are supported for complex; all other implementations are provided by TypeDefault and will error by default.

There will be a canonical listing of methods supported on Type (the overall interface) as an autogenerated file that is checked into the PyTorch source repository; we'll communicate API changes by diffs to this file. In general, the methods are in one-to-one correspondence with their corresponding names in the PyTorch frontend.

In general, when you use an operation which you haven't implemented yet, you will get a runtime error saying that the operation is not supported for complex tensors.

WARNING: We intend to refactor Type away into a new system that also supports open registration of new operations (this obviously doesn't work if you have a single superclass that defines all the methods you might possibly want to support). Thus, try not to get too tied to this particular implementation strategy of subclassing Type.

To publish new, complex-only operations, you will use the C++ extension API. The C++ extension API is documented at https://pytorch.org/tutorials/advanced/cpp_extension.html. Essentially, you can write a C++ function like:

at::Tensor imag(at::Tensor z) {
  ...
}

And then the C++ extension API will generate a Python binding so that you can invoke this function from Python.
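
For concreteness, a complete extension file following the linked tutorial could look roughly like the sketch below. Only the PYBIND11_MODULE binding pattern is taken from the tutorial; the body of imag is purely illustrative and assumes the complex values are stored as an interleaved (..., 2) float tensor, which is not how the eventual complex dtype will be laid out.

// complex_ext.cpp -- minimal sketch of a C++ extension exposing imag().
#include <torch/extension.h>

at::Tensor imag(at::Tensor z) {
  // Illustrative only: treat the last dimension as (real, imag) pairs
  // and return the imaginary component.
  return z.select(-1, 1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("imag", &imag, "imaginary part of a complex tensor");
}

On the Python side, this file can be built and loaded with torch.utils.cpp_extension (e.g. load, or CppExtension in a setup.py), as described in the tutorial.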

Some operations will be “easy” to integrate into PyTorch as it exists today. For example, for the implementation of binary operations, it probably makes more sense to extend add_kernel in BinaryOpsKernel.cpp so that it dispatches over complex types (and then you get complex addition for free, because std::complex already implements it). As long as these patches are small and self-contained, we promise to merge them on a timely basis.
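
A rough sketch of the kind of change this means is below. It is not the actual code: the add_kernel signature is approximate, and AT_DISPATCH_ALL_TYPES_AND_COMPLEX is an assumed name for a dispatch macro that also instantiates the lambda for complex scalar types.

// Sketch only (this would live inside BinaryOpsKernel.cpp, where
// TensorIterator, cpu_kernel, and the dispatch macros are in scope).
void add_kernel(TensorIterator& iter, Scalar alpha_scalar) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "add", [&]() {
    auto alpha = alpha_scalar.to<scalar_t>();
    // std::complex defines + and *, so the same lambda covers complex types.
    cpu_kernel(iter, [=](scalar_t a, scalar_t b) -> scalar_t {
      return a + alpha * b;
    });
  });
}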

It should ALWAYS be possible to unblock yourself by just writing an override on Type instead of using the existing infrastructure, and doing liberal copy-pasting. But let's avoid that when the easy route is available!

Autograd. As long as you're working on operations which already have derivative formulas defined for them, you will “automatically” get autograd support, provided you implement complex support for all the constituent functions which are invoked in the backwards implementation from derivatives.yaml.

In some cases, we may need to adjust autograd formulas so that they work for complex numbers; e.g., the gradient of 'abs' isn't 'grad . self.sign()'. In these cases, all we need to do is upstream a fix which changes the autograd formula of 'abs' to call 'abs_backward', a function that can then be overridden.
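
As a hedged sketch, an overridable abs_backward might be structured as follows. is_complex() is an assumption about how the new dtype will be queried, and the complex branch uses grad * z / |z| (one common convention for real-valued losses); the exact formula has to match whatever convention autograd settles on.

// Sketch of an overridable abs_backward.
at::Tensor abs_backward(const at::Tensor& grad, const at::Tensor& self) {
  if (self.is_complex()) {               // assumed dtype query
    // One convention for real-valued losses: propagate grad * z / |z|.
    // (z == 0 would need the same special-casing as sign.)
    return grad * self / self.abs();
  }
  return grad * self.sign();             // existing real-valued formula
}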

For general complex valued back propagation, there are some references:

  1. Akira Hirose’s “Complex-Valued Neural Networks”.
  2. https://giggleliu.github.io/2018/02/01/complex_bp.html

Generally, we won't need to modify autograd itself, since in most cases we only compute the derivatives of a real-valued function (the loss).

Work plan

Many of the necessary pieces are in place today, but they are not put together in an end-to-end way. Here is what needs to be done.

Short-term integration plan. These operations are “easy” to implement, and so we should mainline them in PyTorch as soon as possible.

Kernel implementation:

TODO: Generate a list based on https://github.com/Roger-luo/TH/blob/master/ChangeLog.md

Other complex-related tasks:

Historical issue content

Original comment from @PhilippPelz

I was wondering if there is interest in incorporating complex tensors into pytorch.
For CPU support there is ztorch, and I wrote z-cutorch (https://github.com/PhilippPelz/z-cutorch) a while ago. It is a fork of cutorch from before the refactoring for CudaHalfTensor (I don't have the hardware yet).
If it's not too much work, I would like to slowly integrate it with pytorch. I am using matplotlib for plotting via fb.python, and it turns into a huge pain every time I reinstall my system (compiling all the dependencies); plus it seems pytorch will work under Windows soon, which one of my experiment PCs runs on.
I would also need complex gradients, so I would sooner or later touch autograd as well.
While tf supports complex tensors per se, it seems many ops don't support them yet (tensorflow/tensorflow#2255); plus it seems a bit heavyweight for my purposes.

Maybe someone could say a few words how and where to start with this, if it's a welcome idea.

Labels: feature (a request for a proper, new feature), module: complex (related to complex number support in PyTorch), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module).
