
CheckpointManager randomly crashes the JupyterLab kernel #1408

@vfdev-5

The following code, split across multiple cells in JupyterLab, randomly crashes the kernel when executed:

import jax
import flax
from flax import nnx
import orbax.checkpoint as ocp

print(jax.__version__)
print(flax.__version__)
print(ocp.__version__)

path = ocp.test_utils.erase_and_create_empty("/tmp/repro-crash-model/")
options = ocp.CheckpointManagerOptions(max_to_keep=2)
mngr = ocp.CheckpointManager(path, options=options)


def save_model(epoch):
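    # Assumes a global `model` (a TestModel instance) defined elsewhere in the notebook.
    state = nnx.state(model)
    # Convert typed PRNG key arrays (e.g. the Dropout layer's RNG state) to raw
    # uint32 key data, the legacy key format, so the state can be serialized.
    # https://github.com/google/flax/issues/4231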
    def get_key_data(x):
        # Note: jax._src is a private module, so these checks may break across JAX versions.
        if isinstance(x, jax._src.prng.PRNGKeyArray):
            if isinstance(x.dtype, jax._src.prng.KeyTy):
                return jax.random.key_data(x)
        return x

    serializable_state = jax.tree.map(get_key_data, state)
    mngr.save(epoch, args=ocp.args.StandardSave(serializable_state))
    mngr.wait_until_finished()

class TestModel(nnx.Module):
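    # The default nnx.Rngs(0) is created once, when the class is defined, and is
    # shared by every instance constructed without an explicit `rngs` argument.
    def __init__(self, n1, n2, n3, *, rngs=nnx.Rngs(0)):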
        self.m1 = nnx.Linear(n1, n2, rngs=rngs)
        self.dropout = nnx.Dropout(0.1, rngs=rngs)
        self.m2 = nnx.Linear(n2, n3, rngs=rngs)

    def __call__(self, x):
        x = self.m1(x)
        x = self.dropout(x)
        x = self.m2(x)
        return x
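The `get_key_data` helper in `save_model` relies on the private `jax._src.prng` module. A minimal sketch of the same conversion using only public JAX APIs (`jax.dtypes.prng_key`, `jax.random.key_data`), together with the inverse step a restore would presumably need (`jax.random.wrap_key_data`), might look like this. The `wrap_key_data_like` helper and the toy `state` dict are hypothetical; the dict merely stands in for `nnx.state(model)`, and nothing here is claimed to affect the crash itself.

```python
import jax
import jax.numpy as jnp


def is_typed_key(x):
    # Typed key arrays (e.g. from jax.random.key) share the abstract dtype
    # jax.dtypes.prng_key, so no private jax._src classes are needed.
    return isinstance(x, jax.Array) and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


def get_key_data(x):
    """Return raw uint32 key data for typed key arrays; pass other leaves through."""
    if is_typed_key(x):
        return jax.random.key_data(x)
    return x


def wrap_key_data_like(raw, template):
    """Hypothetical helper: re-wrap raw key data using the key impl of `template`."""
    if is_typed_key(template):
        return jax.random.wrap_key_data(raw, impl=jax.random.key_impl(template))
    return raw


# Toy pytree standing in for nnx.state(model): one typed key plus ordinary weights.
state = {"dropout_key": jax.random.key(0), "w": jnp.ones((2, 3))}

# Before saving: typed keys become plain uint32 arrays that a checkpointer can write.
serializable = jax.tree.map(get_key_data, state)
print(serializable["dropout_key"].dtype)  # uint32

# After restoring: map over the restored tree and the original as a template.
restored = jax.tree.map(wrap_key_data_like, serializable, state)
```

Checking against `jax.dtypes.prng_key` covers every typed key implementation, rather than depending on the internal `PRNGKeyArray`/`KeyTy` classes that may move between JAX releases.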

Here is the notebook (.ipynb) that reproduces the crash: https://gist.github.com/vfdev-5/da7f01b9c0c2948046b227a32db2cb2b

