
Host callback attempts to perform computation on GPU #13046

@mfschubert

Description

I am trying to run a computation that includes `jax.numpy.linalg.eig` on a GPU.

Since `eig` is not implemented for the GPU backend, I am trying to use the experimental `host_callback` module to force the calculation to run on the CPU. Unfortunately, JAX still seems to attempt to run it on the GPU, which fails.

The following snippet reproduces the error, for example in Colab with a GPU runtime.

```python
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Performs an `eig` solve on the host (CPU)."""
    # Declare what the host function returns:
    # eigenvalues have shape (..., n), eigenvectors have shape (..., n, n).
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)
    return hcb.call(
        jnp.linalg.eig,
        matrix.astype(complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )


x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
_eig_host(x)  # This fails because `eig` is not implemented for GPU.
```
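One plausible reading of the failure, offered as an assumption rather than a confirmed diagnosis: `hcb.call` runs the callback on the host and passes it NumPy arrays, but `jnp.linalg.eig` inside the callback is still a JAX operation. It is therefore dispatched to JAX's default device, which is the GPU in a GPU runtime. The minimal sketch below keeps the host-side work entirely in NumPy. `eig_on_host` is a hypothetical name and is not part of JAX. It casts its outputs to Python `complex` (NumPy complex128), mirroring the `complex` dtype requested in `result_shape` above.

```python
from typing import Tuple

import numpy as np


def eig_on_host(matrix: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical NumPy-only stand-in for `jnp.linalg.eig` in a host callback.

    Because it only calls NumPy, no work is dispatched to an accelerator.
    Supports stacked matrices: input (..., n, n) -> eigenvalues (..., n),
    eigenvectors (..., n, n), matching the result_shape in the snippet above.
    """
    eigenvalues, eigenvectors = np.linalg.eig(np.asarray(matrix))
    # np.linalg.eig may return real arrays when all eigenvalues are real,
    # so cast to complex to keep the output dtype consistent.
    return eigenvalues.astype(complex), eigenvectors.astype(complex)


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 3))
w, v = eig_on_host(x)
# Column v[:, i] satisfies x @ v[:, i] == w[i] * v[:, i] up to rounding.
print(np.allclose(x @ v, v * w))
```

Passing a function like this to `hcb.call` instead of `jnp.linalg.eig` keeps the callback body off the GPU under the assumption above. This has not been verified against the reporter's JAX 0.3.23 setup.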

What jax/jaxlib version are you using?

0.3.23

Which accelerator(s) are you using?

Colab GPU

Additional system info

Colab

NVIDIA GPU info

No response

Labels

NVIDIA GPU (Issues specific to NVIDIA GPUs); bug (Something isn't working)