Closed
Labels: bug (Something isn't working)
Description
In the code below, JAX ops whose inputs are committed to the CPU still run on the default device (TPU) rather than on the CPU. For my use case this runs out of device memory (OOM):
```python
import jax
import jax.numpy as jnp
import numpy as np

def CpuArray(numpy_array):
    # Commit the array to the first CPU device.
    return jax.device_put(numpy_array, device=jax.devices(backend='cpu')[0])

data = CpuArray(np.ones([100, 4], dtype=np.float32))
slice_indexes = CpuArray(np.zeros([100], dtype=np.int32))

# Summing the data runs on TPU.
summed_data = jnp.sum(data[slice_indexes])
# So does indexing it (row 0; index 100 would be out of range for 100 rows).
sliced_data = data[0]

print("data device", data.device_buffer.device())
print("sliced indexes device", slice_indexes.device_buffer.device())
print("sliced data device", sliced_data.device_buffer.device())
print("summed data device", summed_data.device_buffer.device())
```
The inputs are on the CPU, as expected:
```
data device cpu:0
sliced indexes device cpu:0
```
Both outputs, however, are on TPU:
```
sliced data device TPU_0(host=0,(0,0,0,0))
summed data device TPU_0(host=0,(0,0,0,0))
```
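A possible workaround, sketched under assumptions: with a JAX release that provides the `jax.default_device` context manager and the `jax.Array.devices()` method (neither appears in the report above), the CPU can be made the default device for a block of code, and each result's placement can be checked directly. `sum_gathered` is a hypothetical helper introduced only for illustration, and `data[0]` stands in for an in-range index. This is not necessarily how the issue was resolved.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: a JAX version with jax.default_device and jax.Array.devices().
cpu = jax.devices("cpu")[0]

def sum_gathered(data, idx):
    # Hypothetical helper: gather the rows of `data` selected by `idx`,
    # then reduce them to a scalar.
    return jnp.sum(data[idx])

# Commit both inputs to the CPU, as CpuArray does in the report.
data = jax.device_put(np.ones([100, 4], dtype=np.float32), cpu)
slice_indexes = jax.device_put(np.zeros([100], dtype=np.int32), cpu)

# Inside this block, values created without an explicit device
# (e.g. constants) are also placed on the CPU rather than the accelerator.
with jax.default_device(cpu):
    summed_data = sum_gathered(data, slice_indexes)
    sliced_data = data[0]

# Array.devices() returns the set of devices holding the array's data.
for name, arr in [("summed data", summed_data), ("sliced data", sliced_data)]:
    print(name, "device", arr.devices())
```

Because `device_put` commits `data` and `slice_indexes` to the CPU, recent JAX versions are expected to run these ops on the CPU even without the context manager. The `with` block additionally covers any uncommitted values created inside it.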