Open
Labels
NVIDIA GPU (Issues specific to NVIDIA GPUs), bug (Something isn't working)
Description
I am trying to run a computation that includes jax.numpy.linalg.eig on a GPU. Since eig is not implemented for the GPU backend, I am trying to use the experimental host_callback module to force the calculation to take place on the CPU. Unfortunately, it seems an attempt is still made to run eig on the GPU, which fails.

The following code snippet reproduces the error, for example in Colab with a GPU runtime.
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Performs an `eig` solve on the host (CPU)."""
  eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
  eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)
  return hcb.call(
      jnp.linalg.eig,
      matrix.astype(complex),
      result_shape=(eigenvalues_shape, eigenvectors_shape),
  )


x = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
_eig_host(x)  # This fails, because `eig` is not implemented for GPU.
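
For reference, a minimal sketch of a possible workaround, assuming it is acceptable to run the decomposition with plain NumPy on the host: calling np.linalg.eig inside the callback keeps the solve on the CPU, instead of re-entering JAX (which presumably dispatches jnp.linalg.eig back to the default GPU device). The _eig_host_numpy helper below is my own addition, not part of the original report, and is untested in this exact setup.

import numpy as np


def _eig_host_numpy(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Sketch of a workaround: run NumPy's `eig` inside the host callback."""
  eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
  eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)
  return hcb.call(
      lambda m: np.linalg.eig(m),  # NumPy executes on the host, not the GPU backend
      matrix.astype(complex),
      result_shape=(eigenvalues_shape, eigenvectors_shape),
  )


_eig_host_numpy(x)  # Hypothetical usage; should return (eigenvalues, eigenvectors).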
What jax/jaxlib version are you using?
0.3.23
Which accelerator(s) are you using?
Colab GPU
Additional system info
Colab
NVIDIA GPU info
No response