Closed
Labels: question (Questions for the JAX team)
Description
I'm running JAX from Jupyter and I've noticed behaviour I don't understand (it might be a bug, but probably isn't):
I'm calling an un-jitted function f which, internally, calls another, jitted function g. g returns some value x, and within f I set an attribute to x (I'm leaving out the surrounding details because I assume they aren't important; I can post the entire thing if it helps):
def f(self, ...):
    @jit
    def g(...):
        # stuff g does
        ...
    x = g(...)
    self.factors = x
From Jupyter, I call foo.f(...) and it works fine. What I don't understand is this: when I run
%time foo.factors.sum()
it reports a time of several microseconds (this is good); however, it takes roughly one minute for the actual value of foo.factors.sum() to be displayed on screen. Is this expected behaviour?
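A minimal reproducer sketch, assuming a recent JAX release rather than 0.1.35 and assuming the lag comes from JAX's asynchronous dispatch (in which case %time would measure only how long it takes to hand the work off, not to compute it). The class Foo, the size argument n and the matrix-product workload inside g are hypothetical stand-ins for the elided details; block_until_ready() is used to separate "the call returned" from "the value is actually available". Absolute timings will vary by machine.

```python
import time

import jax
import jax.numpy as jnp
from jax import jit


class Foo:
    """Hypothetical stand-in for the object in the report."""

    def f(self, n):
        @jit
        def g(key):
            # Hypothetical workload standing in for "stuff g does":
            # a chain of matrix products heavy enough to take measurable time.
            a = jax.random.normal(key, (n, n))
            for _ in range(5):
                a = jnp.tanh(a @ a)
            return a

        # With asynchronous dispatch, g may return before its result is computed.
        x = g(jax.random.PRNGKey(0))
        self.factors = x


foo = Foo()
foo.f(500)

# Measure how long the call takes to return (roughly what %time reports).
t0 = time.perf_counter()
s = foo.factors.sum()
t_call = time.perf_counter() - t0

# Measure the additional time until the value has actually been computed.
t0 = time.perf_counter()
s.block_until_ready()
t_ready = time.perf_counter() - t0

print(f"call returned after {t_call:.6f} s; value ready after a further {t_ready:.6f} s")
print(float(s))
```

If the second number accounts for most of the delay, the slow part is the computation itself rather than displaying the result; timing with `%time foo.factors.sum().block_until_ready()` in Jupyter would then reflect the full cost.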
Version details:
jax 0.1.35 (macOS, so the CPU-only version)
jaxlib 0.1.16