-
Notifications
You must be signed in to change notification settings - Fork 73
Open
Labels
Description
The following code, split across multiple cells in JupyterLab, crashes randomly when executed:
import jax
import flax
from flax import nnx
import orbax.checkpoint as ocp
# Report the versions of the libraries involved, one per line.
print(jax.__version__, flax.__version__, ocp.__version__, sep="\n")

# Checkpoint manager configuration.
# NOTE(review): max_to_keep=2 presumably limits retention to the two most
# recent checkpoints -- confirm against the Orbax documentation.
options = ocp.CheckpointManagerOptions(max_to_keep=2)

# Working directory for the checkpoints.
# NOTE(review): judging by its name, the helper wipes and recreates the
# directory -- confirm before pointing it at a path that holds real data.
path = ocp.test_utils.erase_and_create_empty("/tmp/repro-crash-model/")

mngr = ocp.CheckpointManager(path, options=options)
def save_model(epoch):
    """Save the state of the global ``model`` as the checkpoint for ``epoch``.

    Reads the module-level ``model`` and ``mngr`` objects. The call blocks
    until the checkpoint manager reports that the save has finished.

    Args:
        epoch: Step identifier passed to ``mngr.save`` for this checkpoint.
    """
    state = nnx.state(model)

    # We should convert PRNGKeyArray to the old format for Dropout layers
    # https://github.com/google/flax/issues/4231
    def get_key_data(x):
        # Detect typed PRNG keys through the public dtype API instead of the
        # private ``jax._src.prng`` classes, which are not a stable interface.
        if isinstance(x, jax.Array) and jax.dtypes.issubdtype(
            x.dtype, jax.dtypes.prng_key
        ):
            return jax.random.key_data(x)
        return x

    serializable_state = jax.tree.map(get_key_data, state)
    mngr.save(epoch, args=ocp.args.StandardSave(serializable_state))
    mngr.wait_until_finished()
class TestModel(nnx.Module):
    """Small model used by the reproducer: Linear -> Dropout -> Linear."""

    def __init__(self, n1, n2, n3, *, rngs=None):
        """Build the two linear layers and the dropout layer between them.

        Args:
            n1: First argument of the first ``nnx.Linear`` layer.
            n2: Second argument of the first ``nnx.Linear`` layer and first
                argument of the second.
            n3: Second argument of the second ``nnx.Linear`` layer.
            rngs: ``nnx.Rngs`` passed to every sub-layer. When omitted, a
                fresh ``nnx.Rngs(0)`` is created for this instance.
        """
        # A default written as ``rngs=nnx.Rngs(0)`` is created once, when the
        # class body is executed, and then shared by every instance. Build it
        # per construction instead.
        if rngs is None:
            rngs = nnx.Rngs(0)
        self.m1 = nnx.Linear(n1, n2, rngs=rngs)
        self.dropout = nnx.Dropout(0.1, rngs=rngs)
        # This layer was previously also assigned to ``self.m1``, overwriting
        # the first layer and leaving ``self.m2`` (read by ``__call__``) unset.
        self.m2 = nnx.Linear(n2, n3, rngs=rngs)

    def __call__(self, x):
        """Apply the first linear layer, dropout, then the second linear layer."""
        x = self.m1(x)
        x = self.dropout(x)
        x = self.m2(x)
        return x

# Here is the ipynb to reproduce the crash: https://gist.github.com/vfdev-5/da7f01b9c0c2948046b227a32db2cb2b