jax.device_put has complexity superlinear in the number of arguments. #947

@hawkinsp

Reproduction on CPU:

```python
import jax
import numpy as onp

In [9]: %time x = jax.device_put([onp.random.randn(10,5) for _ in range(100)])
CPU times: user 1.45 s, sys: 7.8 ms, total: 1.46 s
Wall time: 1.45 s

In [10]: %time x = jax.device_put([onp.random.randn(10,5) for _ in range(500)])
CPU times: user 24.8 s, sys: 0 ns, total: 24.8 s
Wall time: 24.8 s

In [11]: %time x = jax.device_put([onp.random.randn(10,5) for _ in range(700)])
CPU times: user 45.2 s, sys: 190 ms, total: 45.4 s
Wall time: 45.4 s
```
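To put a number on "superlinear", the reported wall times can be fit to a power law t ≈ c·n^k by least squares in log-log space. Going from 100 to 500 arrays (5×) raises wall time about 17×, and going to 700 arrays (7×) raises it about 31×. The fit over all three points gives k of about 1.77: clearly worse than linear, though short of quadratic. The sketch below is an illustration only and not part of the original report; the helper names `fit_scaling_exponent` and `time_calls` are hypothetical.

```python
import math
import time
from typing import Callable, List, Sequence


def fit_scaling_exponent(sizes: Sequence[int], seconds: Sequence[float]) -> float:
    """Least-squares slope of log(seconds) vs. log(size).

    Assumes time ~ c * n**k and returns the estimated exponent k.
    """
    xs = [math.log(n) for n in sizes]
    ys = [math.log(t) for t in seconds]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var


def time_calls(
    fn: Callable[[list], object],
    make_args: Callable[[int], list],
    sizes: Sequence[int],
) -> List[float]:
    """Wall-clock seconds for one call of fn(make_args(n)) per size n.

    The argument list is built before the timer starts, so only fn is measured.
    """
    results = []
    for n in sizes:
        args = make_args(n)
        start = time.perf_counter()
        fn(args)
        results.append(time.perf_counter() - start)
    return results


# Wall times copied from the IPython session in this issue.
REPORTED_SIZES = [100, 500, 700]
REPORTED_SECONDS = [1.45, 24.8, 45.4]

if __name__ == "__main__":
    k = fit_scaling_exponent(REPORTED_SIZES, REPORTED_SECONDS)
    print(f"empirical scaling exponent: {k:.2f}")
```

To re-measure on a given JAX install, one option (an assumption, not from the report) is to wrap the call in `jax.block_until_ready`. JAX dispatches asynchronously, so this keeps the timer running until the transfer has actually finished:

```python
import jax
import numpy as onp

sizes = [100, 500, 700]
times = time_calls(
    lambda xs: jax.block_until_ready(jax.device_put(xs)),
    lambda n: [onp.random.randn(10, 5) for _ in range(n)],
    sizes,
)
print(fit_scaling_exponent(sizes, times))
```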
