Transient bug in Travis about escaped tracers in examples/control_test #2507

@gnecula

Description

The error below shows up intermittently on Travis. It is suspicious: sometimes it appears only in the x64 test run, and re-running the tests makes it disappear. This points to some state corruption somewhere.

examples/control.py:180: in lqr_predict
    _, X, U = fori_loop(0, T, fwd_loop, (spec, X, U))
examples/control.py:76: in fori_loop
    return lax.fori_loop(lo, hi, loop, init)
jax/lax/lax_control_flow.py:170: in fori_loop
    (lower, upper, init_val))
jax/lax/lax_control_flow.py:227: in while_loop
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
jax/lax/lax_control_flow.py:61: in _initial_style_jaxpr
    wrapped_fun, in_pvals, instantiate=True, stage_out_calls=True)
jax/interpreters/partial_eval.py:393: in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
jax/linear_util.py:154: in call_wrapped
    ans = gen.send(ans)
jax/interpreters/partial_eval.py:407: in trace_to_subjaxpr
    out_tracers = map(trace.full_raise, map(core.full_lower, ans))
    
E     jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
E     The functions being transformed should not save traced values to global state.
E     Details: Can't lift level Traced<ShapedArray(float64[11,2]):JaxprTrace(level=11/0)> to JaxprTrace(level=8/0).
jax/core.py:348: UnexpectedTracerError
=========================== short test summary info ============================
FAILED examples/control_test.py::ControlExampleTest::testIlqrWithLqrProblemSpecifiedGenerally
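The error message says this class of failure usually means a traced value escaped through global state. The sketch below illustrates that failure mode. It is not a reproduction of this flaky test, whose actual cause is still unknown. It leaks a tracer through a module-level list and then uses it after its trace has finished. The names `leaked`, `save_to_global` and `use_leaked` are hypothetical. The sketch assumes a recent JAX, where the error is exposed publicly as `jax.errors.UnexpectedTracerError`.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level list standing in for "global state".
leaked = []

def save_to_global(x):
    # Side effect during tracing: inside jit, x is a tracer, not a concrete value.
    leaked.append(x)
    return x * 2.0

# Tracing happens once here; the tracer stays in `leaked` after the trace ends.
jax.jit(save_to_global)(jnp.float32(1.0))

def use_leaked():
    """Use the escaped tracer outside the transformation that created it."""
    try:
        return leaked[0] + 1.0
    except jax.errors.UnexpectedTracerError as err:
        return err

result = use_leaked()
print(type(result).__name__)
```

In this issue the error comes from `trace_to_subjaxpr` while tracing a `while_loop` body. A tracer from a different trace level (11 vs. 8) ends up among the outputs, which fits with the state-corruption theory. The sketch only shows the general mechanism.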

Metadata

Assignees: no one assigned
Labels: bug (Something isn't working)