
jax.grad computes incorrect derivative for polynomials #14397

@wuxishy

Description


jax.grad does not handle the constant term of a polynomial correctly and produces nan when differentiating at 0.
Here is an example where differentiating x^2 + x + 1 at x = 0 yields nan:

import jax
import jax.numpy as jnp

def f(x):
    # 1 + x + x^2, written as x**0 + x**1 + x**2
    return jnp.sum(x**jnp.arange(3))

jax.grad(f)(0.0)
# output is Array(nan, dtype=float32, weak_type=True)

The issue is that the derivative rule for x**c subtracts 1 from every exponent, including the constant term's exponent 0, so the gradient evaluates 0 * x**(-1); at x = 0 this is 0 * inf, which is nan.
To illustrate this, here is the jaxpr of the gradient:

>>> jax.make_jaxpr(jax.grad(f))(0.0)
{ lambda ; a:f32[]. let
    b:i32[3] = iota[dimension=0 dtype=int32 shape=(3,)] 
    c:f32[3] = convert_element_type[new_dtype=float32 weak_type=True] b
    d:f32[3] = pow a c
    e:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 1.0
    f:f32[3] = sub c e
    g:f32[3] = pow a f
    h:f32[3] = mul c g
    i:f32[3] = convert_element_type[new_dtype=float32 weak_type=False] d
    _:f32[] = reduce_sum[axes=(0,)] i
    j:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] 1.0
    k:f32[3] = convert_element_type[new_dtype=float32 weak_type=True] j
    l:f32[3] = mul k h
    m:f32[] = reduce_sum[axes=(0,)] l
  in (m,) }
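
The term-by-term gradient is c * x**(c - 1) (the sub/pow/mul chain producing h above). Evaluating those terms by hand at x = 0 shows where the nan enters; a minimal sketch reusing the exponents from f:

import jax.numpy as jnp

c = jnp.arange(3, dtype=jnp.float32)
print(0.0 ** (c - 1.0))      # [inf  1.  0.] -- 0**(-1) overflows to inf
print(c * 0.0 ** (c - 1.0))  # [nan  1.  0.] -- 0 * inf is nan
# reduce_sum over these terms then propagates the nan.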

I also noticed that this bug does not exist in earlier versions of jax (I checked jax 0.2.10 with jaxlib 0.1.62).
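
In the meantime, a possible workaround (my own suggestion, not from the jax docs) is to keep the constant term out of the power expression so that no exponent drops below zero after differentiation:

import jax
import jax.numpy as jnp

def f_safe(x):
    # 1 + x + x^2 with the constant written explicitly;
    # the smallest exponent after differentiation is now 0, not -1
    return 1.0 + jnp.sum(x**jnp.arange(1, 3))

jax.grad(f_safe)(0.0)  # -> 1.0, no nan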

What jax/jaxlib version are you using?

jax 0.4.3, jaxlib 0.4.3

Which accelerator(s) are you using?

CPU

Additional system info

No response

NVIDIA GPU info

No response

Metadata

Labels

P1 (soon): Assignee is working on this now, among other tasks. (Assignee required)
bug: Something isn't working
