Skip to content

Commit fa3b7b9

Browse files
superbobry and Google-ML-Automation
authored andcommitted
Removed the no longer necessary jax-aval-named-shape deprecation
PiperOrigin-RevId: 810441482
1 parent 6a109ea commit fa3b7b9

File tree

2 files changed

+0
-12
lines changed

2 files changed

+0
-12
lines changed

jax/_src/array.py

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@
2727
from jax._src import basearray
2828
from jax._src import config
2929
from jax._src import core
30-
from jax._src import deprecations
3130
from jax._src import dispatch
3231
from jax._src import dtypes
3332
from jax._src import errors
@@ -124,16 +123,6 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
124123
np_value = fun(*args)
125124
np_value.__setstate__(arr_state)
126125
jnp_value = api.device_put(np_value)
127-
# TODO(slebedev): Remove this branch after December 10th 2024.
128-
if "named_shape" in aval_state:
129-
deprecations.warn(
130-
"jax-aval-named-shape",
131-
"Pickled array contains an aval with a named_shape attribute. This is"
132-
" deprecated and the code path supporting such avals will be removed."
133-
" Please re-pickle the array.",
134-
stacklevel=2,
135-
)
136-
del aval_state["named_shape"]
137126
jnp_value.aval = jnp_value.aval.update(**aval_state)
138127
return jnp_value
139128

jax/_src/deprecations.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
124124
# Register a number of deprecations: we do this here to ensure they're
125125
# always registered by the time `accelerate` and `is_acelerated` are called.
126126
register('default-dtype-bits-config')
127-
register('jax-aval-named-shape')
128127
register('jax-lax-dot-positional-args')
129128
register('jax-nn-one-hot-float-input')
130129
register("jax-numpy-astype-complex-to-real")

0 commit comments

Comments
 (0)