
REPL latency issue #802

@adamhaber

Description


I'm running JAX from Jupyter and I've noticed a behaviour I don't understand (it might be a bug, but probably isn't):
I'm calling an un-jitted function f that internally calls another, jitted function g. g returns some value x, and f then sets an attribute to x (I've left out the surrounding details because I assume they're not important; I can post the whole thing if it helps):

def f(self, ...):
    @jit
    def g(...):
        ...  # stuff g does
    x = g(...)
    self.factors = x
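A minimal runnable sketch of the pattern described above, assuming a recent JAX release. The class name `Foo`, the arguments `a` and `b`, and the body of `g` (a matrix product standing in for "stuff g does") are hypothetical placeholders, not the original code.

```python
import jax.numpy as jnp
from jax import jit


class Foo:
    def f(self, a, b):
        # Hypothetical stand-in for "stuff g does": a jitted matrix product.
        @jit
        def g(a, b):
            return jnp.dot(a, b)

        x = g(a, b)       # a JAX array; the computation may still be running
        self.factors = x  # store the result as an attribute


foo = Foo()
foo.f(jnp.ones((3, 4)), jnp.ones((4, 2)))
print(foo.factors.sum())  # every entry is 4.0, so the 3x2 sum is 24.0
```

One design note: because `g` is a new function object on every call to `f`, `jit` may retrace and recompile it each time; defining `g` once, outside `f`, would likely avoid that cost.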

From Jupyter, I call foo.f(...) and it works fine. What I don't understand is this: when I run

%time foo.factors.sum()

it reports a wall time of several microseconds (which is good); however, it takes roughly one minute before the actual value of foo.factors.sum() is displayed. Is this expected behaviour?
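One likely explanation, offered as an assumption rather than something confirmed in this thread, is JAX's asynchronous dispatch: a jitted call typically returns as soon as the work has been enqueued, so `%time` measures only the dispatch, and the real wait happens when the value is needed (for example, when it is printed). The sketch below, assuming a recent JAX version in which arrays provide `block_until_ready()` (the reporter's 0.1.35 may predate it), times the same call with and without blocking. The function `slow_sum` and the matrix size are hypothetical.

```python
import time

import jax.numpy as jnp
from jax import jit


@jit
def slow_sum(a):
    # Hypothetical workload: three 1000x1000 matrix products, then a reduction.
    for _ in range(3):  # unrolled at trace time
        a = jnp.tanh(a @ a)
    return a.sum()


a = jnp.ones((1000, 1000)) / 1000.0
slow_sum(a).block_until_ready()  # warm-up: compile once, outside the timings

# Without blocking, the call typically returns once the work is dispatched.
start = time.perf_counter()
result = slow_sum(a)
dispatch_time = time.perf_counter() - start
result.block_until_ready()  # let it finish before the next measurement

# With block_until_ready(), the timer also covers the computation itself.
start = time.perf_counter()
result = slow_sum(a).block_until_ready()
total_time = time.perf_counter() - start

print(f"dispatch only:          {dispatch_time:.6f} s")
print(f"with block_until_ready: {total_time:.6f} s")
print(float(result))  # converting to a Python float also waits for the value
```

If the blocked time is much larger than the dispatch-only time, the delay seen in Jupyter is most likely the computation itself finishing, not the cost of displaying the result.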

Version details:
JAX - 0.1.35 (macOS, so the CPU-only version)
jaxlib - 0.1.16

Metadata


Labels: question (Questions for the JAX team)
