Closed
Labels: P1 (soon), bug
Description
`jax.grad` does not handle the constant term of a polynomial correctly and returns `nan` when differentiating at 0.
Here is an example where differentiating x^2 + x + 1 at 0 results in `nan` (the expected value is 1):
```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**jnp.arange(3))  # 1 + x + x**2

jax.grad(f)(0.0)
# output is Array(nan, dtype=float32, weak_type=True)
```
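For comparison, here is a sketch (not part of the original report) that evaluates the same polynomial with `jnp.polyval`. As far as I know, `jnp.polyval` uses Horner's rule (repeated `y * x + p`) instead of `pow`, so differentiating at 0 never forms `x**(-1)`. `f_horner` is a hypothetical name used only for this illustration.

```python
import jax
import jax.numpy as jnp

def f_horner(x):
    # Same polynomial 1 + x + x**2. jnp.polyval expects the highest-degree
    # coefficient first; every coefficient here is 1.0.
    return jnp.polyval(jnp.array([1.0, 1.0, 1.0]), x)

print(jax.grad(f_horner)(0.0))  # expected: 1.0, i.e. 1 + 2x at x = 0
```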
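The issue is that the derivative rule subtracts 1 from every exponent, including the exponent 0 of the constant term. That term's derivative is therefore evaluated as `0 * x**(-1)`, which at x = 0 is `0 * inf = nan`, and the `nan` propagates through the sum.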
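The arithmetic can be reproduced without `jax.grad` by applying the power rule `c * x**(c - 1)` term by term. This is a minimal sketch that only mimics what the derivative computes; it does not call JAX's differentiation machinery, so its result does not depend on the JAX version.

```python
import jax.numpy as jnp

# Same evaluation point and exponents as in the example above.
x = jnp.asarray(0.0, dtype=jnp.float32)
c = jnp.arange(3, dtype=jnp.float32)  # exponents [0., 1., 2.]

# Power rule applied to each term: d/dx x**c = c * x**(c - 1).
# For the constant term (c = 0) this is 0 * 0**(-1) = 0 * inf = nan.
terms = c * x ** (c - 1)
total = jnp.sum(terms)

print(terms)  # first entry is nan, the others are 1.0 and 0.0
print(total)  # nan, because a single nan poisons the sum
```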
To illustrate, here is the resulting jaxpr:
```
>>> jax.make_jaxpr(jax.grad(f))(0.0)
{ lambda ; a:f32[]. let
    b:i32[3] = iota[dimension=0 dtype=int32 shape=(3,)]
    c:f32[3] = convert_element_type[new_dtype=float32 weak_type=True] b
    d:f32[3] = pow a c
    e:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 1.0
    f:f32[3] = sub c e
    g:f32[3] = pow a f
    h:f32[3] = mul c g
    i:f32[3] = convert_element_type[new_dtype=float32 weak_type=False] d
    _:f32[] = reduce_sum[axes=(0,)] i
    j:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 1.0
    k:f32[3] = convert_element_type[new_dtype=float32 weak_type=True] j
    l:f32[3] = mul k h
    m:f32[] = reduce_sum[axes=(0,)] l
  in (m,) }
```
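In this jaxpr the per-term tangent is `h = mul c (pow a (sub c 1.0))`, so the constant term contributes `0 * pow(0, -1)`. As a user-side workaround, here is a minimal sketch of a `pow` whose derivative in `x` is defined to be exactly 0 wherever the exponent is 0. `safe_pow` and `f_safe` are hypothetical names; this is not necessarily how JAX itself fixed the bug, and the sketch deliberately ignores the derivative with respect to the exponent.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def safe_pow(x, y):
    """Hypothetical workaround: x ** y with d/dx defined as 0 where y == 0."""
    return x ** y

@safe_pow.defjvp
def _safe_pow_jvp(primals, tangents):
    x, y = primals
    x_dot, _ = tangents  # tangent w.r.t. the exponent is ignored in this sketch
    ans = x ** y
    # Substitute y == 0 with 1 before calling pow, so 0 ** -1 = inf is never
    # computed, not even in the branch that jnp.where discards.
    y_safe = jnp.where(y == 0, 1.0, y)
    d_dx = jnp.where(y == 0, 0.0, y_safe * x ** (y_safe - 1.0))
    return ans, d_dx * x_dot

def f_safe(x):
    # Same polynomial as f, with float exponents [0., 1., 2.].
    return jnp.sum(safe_pow(x, jnp.arange(3.0)))

print(jax.grad(f_safe)(0.0))  # expected: 1.0
```

The inner `jnp.where` is the usual "double where" pattern: masking only the output would still evaluate `0 ** -1` in the unselected branch.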
I also noticed that this bug does not exist in earlier versions of JAX (I checked jax 0.2.10 with jaxlib 0.1.62).
What jax/jaxlib version are you using?
jax 0.4.3, jaxlib 0.4.3
Which accelerator(s) are you using?
CPU