2x slowdown using pmap with DeviceArrays on TPU. #2871

@jmgilmer

We've recently observed training running roughly 2x slower when the pmap-ed function is called on JAX arrays (DeviceArrays) instead of NumPy arrays. Casting the data to NumPy arrays before the call is a current workaround.

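The training loop looks roughly like this:

import jax
import numpy as onp

update_params = jax.pmap(update)  # `update`, `data_gen` and `params` are defined elsewhere
for data in data_gen():
    # Workaround: casting the batch to a NumPy array avoids the slowdown
    # data = onp.array(data)
    params = update_params(data, params)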
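The loop above can be turned into a self-contained timing comparison along these lines. This is a minimal sketch, not the original training code: `update` is a hypothetical stand-in (one gradient step on a toy least-squares loss), `time_loop` is a hypothetical helper, and the batch shapes are arbitrary. It runs on whatever devices are available (a single CPU device works); whether the 2x gap reproduces depends on the backend and JAX version.

```python
import time

import jax
import jax.numpy as jnp
import numpy as onp


def update(data, params):
    # Hypothetical stand-in for the real training step (not shown in the issue):
    # one gradient-descent step on a toy least-squares loss.
    def loss(p):
        return jnp.mean((data @ p) ** 2)
    return params - 1e-3 * jax.grad(loss)(params)


def time_loop(update_params, batches, params):
    # Hypothetical helper: runs the pmap-ed step over all batches and
    # returns (elapsed seconds, final params).
    start = time.perf_counter()
    for data in batches:
        params = update_params(data, params)
    params.block_until_ready()  # wait for asynchronously dispatched work
    return time.perf_counter() - start, params


n_dev = jax.local_device_count()
update_params = jax.pmap(update)
params = jnp.ones((n_dev, 128))  # one copy of the params per device

# pmap maps over the leading axis, so it must equal the number of devices.
np_batches = [onp.ones((n_dev, 256, 128), onp.float32) for _ in range(20)]
jax_batches = [jnp.asarray(b) for b in np_batches]  # JAX arrays on the default device

# Warm up (compile) with both input types before timing.
update_params(np_batches[0], params).block_until_ready()
update_params(jax_batches[0], params).block_until_ready()

t_np, p_np = time_loop(update_params, np_batches, params)
t_jax, p_jax = time_loop(update_params, jax_batches, params)
print(f"NumPy inputs: {t_np:.4f}s, JAX inputs: {t_jax:.4f}s")
```

Both loops perform identical arithmetic, so only the input type differs between the two timings.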
