
Ops involving only CPU tensors run on device, not CPU #2905

@johnpjf

Description


In the code below, a JAX op whose inputs are all CPU-backed runs on the default device (a TPU here) instead of on the CPU, which OOMs in my use case:

import numpy as np
import jax
import jax.numpy as jnp

def CpuArray(numpy_array):
  # Place the array explicitly on the first CPU device.
  return jax.device_put(numpy_array, device=jax.devices(backend='cpu')[0])

data = CpuArray(np.ones([100, 4], dtype=np.float32))
slice_indexes = CpuArray(np.zeros([100], dtype=np.int32))
# Gathering rows and summing them runs on TPU.
summed_data = jnp.sum(data[slice_indexes])
# So does indexing a single row. (Index 100 is out of bounds for axis 0 of
# size 100; JAX clamps it to 99 rather than raising an error.)
sliced_data = data[100]

print("data device", data.device_buffer.device())
print("sliced indexes device", slice_indexes.device_buffer.device())
print("sliced data device", sliced_data.device_buffer.device())
print("summed data device", summed_data.device_buffer.device())

Prints:

data device cpu:0
sliced indexes device cpu:0
sliced data device TPU_0(host=0,(0,0,0,0))
summed data device TPU_0(host=0,(0,0,0,0))

Both output arrays end up on the TPU, even though every input was placed on cpu:0.
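A minimal workaround sketch for keeping this computation on the host, assuming a recent JAX release with the `jax.Array` API (where `Array.devices()` replaces the `device_buffer.device()` call used above) and an available CPU backend. In those releases, `jax.device_put(x, device)` commits the array to that device and ops on committed inputs run on it; separately, the `jax.default_device` context manager changes where uncommitted arrays and computations are placed. This is not necessarily how the issue was resolved, and the helper name `gather_and_sum` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

# First CPU device (assumes the CPU backend is present, as in standard installs).
cpu = jax.devices("cpu")[0]


def gather_and_sum(data, idx):
    # Hypothetical helper: the same gather-then-reduce as in the report.
    return jnp.sum(data[idx])


# Option 1: commit the inputs to the CPU with an explicit device_put.
# In recent JAX, ops whose inputs are committed to one device run on that device.
data = jax.device_put(np.ones([100, 4], dtype=np.float32), cpu)
idx = jax.device_put(np.zeros([100], dtype=np.int32), cpu)
summed_committed = gather_and_sum(data, idx)

# Option 2: make the CPU the default device for this block, so arrays created
# and computations run inside it are placed on the CPU.
with jax.default_device(cpu):
    ones = jnp.ones([100, 4], dtype=jnp.float32)
    zeros = jnp.zeros([100], dtype=jnp.int32)
    summed_default = jax.jit(gather_and_sum)(ones, zeros)
    row = ones[99]  # in-bounds index for the last row

print("summed (committed) on", summed_committed.devices())
print("summed (default_device) on", summed_default.devices())
print("row on", row.devices())
```

Option 1 keeps the original `device_put`-based reproducer unchanged apart from the API used to inspect placement; Option 2 avoids per-array placement at the cost of affecting everything inside the `with` block.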

Labels: bug (Something isn't working)