Description
While running this notebook in Google Colab (I am on Colab Pro), I get the following error:
XlaRuntimeError Traceback (most recent call last)
in <cell line: 9>()
9 for i in trange(max(n_predictions // jax.device_count(), 1)):
10 # get a new key
---> 11 key, subkey = jax.random.split(key)
12 # generate images
13 encoded_images = p_generate(
10 frames
[... skipping hidden 2 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, has_host_callbacks, *args)
893 runtime_token = None
894 else:
--> 895 out_flat = compiled.execute(in_flat)
896 check_special(name, out_flat)
897 out_bufs = unflatten(out_flat, output_buffer_counts)
XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/gpu/prng_kernels.cc:33: operation gpuGetLastError() failed: cudaGetErrorString symbol not found
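To narrow this down, it may help to run the failing call on its own, outside the generation loop, and report the installed versions. The snippet below is a minimal sketch, not part of the original notebook. It assumes the failure is in `jax.random.split` itself rather than in the model code. It also rests on an unconfirmed assumption: a "symbol not found" error from a CUDA call often means the installed `jaxlib` was built against a different CUDA version than the one the runtime provides.

```python
import jax
import jaxlib.version

# Report versions and visible devices so the JAX/jaxlib/CUDA combination
# can be compared (assumption: a CUDA/jaxlib mismatch is the likely cause).
print("jax", jax.__version__)
print("jaxlib", jaxlib.version.__version__)
print("backend", jax.default_backend())
print("devices", jax.devices())

# Minimal reproduction of the call that fails in the traceback:
# split a PRNG key on the default device.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

# Copying the result to the host forces any pending device error to surface here.
print(jax.device_get(subkey))
```

If this snippet raises the same `XlaRuntimeError`, the problem lies in the JAX installation rather than the notebook code. If it succeeds on CPU but fails on GPU, that also points toward the CUDA setup.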