
Commit d179184

Merge pull request #5149 from IvyZX:release

PiperOrigin-RevId: 846407227

Author: Flax Authors
Parents: 9213cfd + 6d36623

File tree

10 files changed: +18 -18 lines changed


README.md

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ To cite this repository:
   author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
   title = {{F}lax: A neural network library and ecosystem for {JAX}},
   url = {http://github.com/google/flax},
-  version = {0.12.1},
+  version = {0.12.2},
   year = {2024},
 }
 ```

docs/guides/parallel_training/flax_on_pjit.ipynb

Lines changed: 3 additions & 3 deletions
@@ -669,7 +669,7 @@
 "  state = state.apply_gradients(grads=grads)\n",
 "  return state\n",
 "\n",
-"with mesh:\n",
+"with jax.set_mesh(mesh):\n",
 "  new_state = train_step(initialized_state, x)"
 ]
 },
@@ -824,7 +824,7 @@
 "def apply_fn(state, x):\n",
 "  return state.apply_fn({'params': state.params}, x)\n",
 "\n",
-"with mesh:\n",
+"with jax.set_mesh(mesh):\n",
 "  y = apply_fn(new_state, x)\n",
 "print(type(y))\n",
 "print(y.dtype)\n",
@@ -861,7 +861,7 @@
 "  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)\n",
 "  return xs\n",
 "\n",
-"with mesh:\n",
+"with jax.set_mesh(mesh):\n",
 "  new_state = block_all(train_step(initialized_state, x))"
 ]
 },
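The same one-line substitution recurs in every documentation and example file below: a bare `with mesh:` block, which activated the mesh through JAX's legacy global mesh context, becomes `with jax.set_mesh(mesh):`, which installs the mesh as the ambient mesh that newer JAX APIs consult. A minimal sketch of the new pattern follows; the mesh shape, axis names, and array sizes are illustrative assumptions, not taken from this commit.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Assumption: 8 visible devices, e.g. CPU with
# XLA_FLAGS=--xla_force_host_platform_device_count=8.
mesh = jax.make_mesh((2, 4), ('data', 'model'))

with jax.set_mesh(mesh):
  # Inside the block, `mesh` is the ambient mesh, so computations can
  # resolve the 'data' and 'model' axis names without a `with mesh:` context.
  x = jax.device_put(jnp.ones((8, 32)), NamedSharding(mesh, P('data', None)))
  print(x.sharding)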

docs/guides/parallel_training/flax_on_pjit.md

Lines changed: 3 additions & 3 deletions
@@ -319,7 +319,7 @@ def train_step(state, x):
   state = state.apply_gradients(grads=grads)
   return state

-with mesh:
+with jax.set_mesh(mesh):
   new_state = train_step(initialized_state, x)
 ```
@@ -338,7 +338,7 @@ Then, create a compiled inference step. Note that the output is also sharded alo
 def apply_fn(state, x):
   return state.apply_fn({'params': state.params}, x)

-with mesh:
+with jax.set_mesh(mesh):
   y = apply_fn(new_state, x)
 print(type(y))
 print(y.dtype)
@@ -357,7 +357,7 @@ def block_all(xs):
   jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
   return xs

-with mesh:
+with jax.set_mesh(mesh):
   new_state = block_all(train_step(initialized_state, x))
 ```

docs_nnx/flip/4844-var-eager-sharding.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def create_sharded_model():
   return model

 mesh = jax.make_mesh(((2, 4)), ("data", "model"))
-with mesh:
+with jax.set_mesh(mesh):
   sharded_model = create_sharded_model()
 ```

docs_nnx/guides/bridge_guide.ipynb

Lines changed: 2 additions & 2 deletions
@@ -689,7 +689,7 @@
 "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),\n",
 "                         axis_names=('in', 'out'))\n",
 "x = jax.random.normal(jax.random.key(42), (4, 32))\n",
-"with mesh:\n",
+"with jax.set_mesh(mesh):\n",
 "  model = create_sharded_nnx_module(x)\n",
 "\n",
 "print(type(model.w))  # `nnx.Param`\n",
@@ -763,7 +763,7 @@
 "      nn.get_partition_spec(variables))\n",
 "  return sharded_vars\n",
 "\n",
-"with mesh:\n",
+"with jax.set_mesh(mesh):\n",
 "  variables = create_sharded_variables(jax.random.key(0), x)\n",
 "\n",
 "# The underlying JAX array is sharded across the 2x4 mesh\n",

docs_nnx/guides/bridge_guide.md

Lines changed: 2 additions & 2 deletions
@@ -370,7 +370,7 @@ print(f'We have {len(jax.devices())} fake JAX devices now to partition this mode
 mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),
                          axis_names=('in', 'out'))
 x = jax.random.normal(jax.random.key(42), (4, 32))
-with mesh:
+with jax.set_mesh(mesh):
   model = create_sharded_nnx_module(x)

 print(type(model.w))  # `nnx.Param`
@@ -422,7 +422,7 @@ def create_sharded_variables(key, x):
       nn.get_partition_spec(variables))
   return sharded_vars

-with mesh:
+with jax.set_mesh(mesh):
   variables = create_sharded_variables(jax.random.key(0), x)

 # The underlying JAX array is sharded across the 2x4 mesh

examples/lm1b/utils.py

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ def setup_initial_state(model, tx, config, rng, mesh):
   state_logical_annotations = nn.get_partition_spec(abstract_state)

   # Initialization
-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+  with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
     state_mesh_annotations = nn.logical_to_mesh_sharding(
       state_logical_annotations, mesh, config.logical_axis_rules
     )

flax/core/meta.py

Lines changed: 1 addition & 4 deletions
@@ -30,7 +30,6 @@
 from flax import errors, struct
 from flax.typing import LogicalNames
 import jax
-from jax.interpreters import pxla

 A = TypeVar('A')
 B = TypeVar('B')
@@ -182,9 +181,7 @@ def inner_update(c, v):
 def get_global_mesh() -> jax.sharding.AbstractMesh | jax.sharding.Mesh | None:
   mesh = jax.sharding.get_abstract_mesh()
   if mesh.empty:
-    mesh = pxla.thread_resources.env.physical_mesh
-    if mesh.empty:
-      return None
+    return None
   return mesh
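This hunk is the behavioral core of the change: `get_global_mesh` now consults only `jax.sharding.get_abstract_mesh()`, and the fallback to the legacy `pxla.thread_resources` physical mesh (the machinery behind the old `with mesh:` idiom) is removed. That is why the documentation hunks above migrate to `jax.set_mesh`. A rough sketch of the observable difference, assuming eight devices; the mesh shape and axis names are illustrative:

import jax
from flax.core import meta  # get_global_mesh, as defined in the hunk above

mesh = jax.make_mesh((2, 4), ('data', 'model'))

print(meta.get_global_mesh())  # None: no ambient mesh has been installed

with jax.set_mesh(mesh):
  # jax.set_mesh makes the mesh visible to jax.sharding.get_abstract_mesh(),
  # so the lookup succeeds without the removed pxla fallback.
  print(meta.get_global_mesh())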

flax/version.py

Lines changed: 1 addition & 1 deletion
@@ -13,4 +13,4 @@
 # limitations under the License.

 """Current Flax version at head on Github."""
-__version__ = '0.12.1'
+__version__ = '0.12.2'

tests/nnx/module_test.py

Lines changed: 3 additions & 0 deletions
@@ -693,6 +693,9 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs):
     nnx.set_mode(block, deterministic=True, use_running_average=True, unknown=True)

   def test_cloud_pickle(self):
+    import platform
+    if platform.python_version().startswith('3.11'):
+      self.skipTest("Cloudpickle cannot pickle PRNGKeyArray on python 3.11")
     class Model(nnx.Module):
       def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
         self.linear = nnx.Linear(din, dmid, rngs=rngs)
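The guard compares against `platform.python_version()`, which returns the full version string (for example '3.11.9'), so the prefix check matches any 3.11.x patch release. A hypothetical equivalent using `sys.version_info`, shown only for comparison and not part of this commit:

import sys
import unittest

class ModuleTest(unittest.TestCase):  # stand-in for the test class above
  def test_cloud_pickle(self):
    # Same intent as the added guard, keyed off sys.version_info
    # instead of a version-string prefix.
    if sys.version_info[:2] == (3, 11):
      self.skipTest('Cloudpickle cannot pickle PRNGKeyArray on python 3.11')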
