
pow grad has surprising behavior based on dtype #17995

@mattjj

Hi @mattjj, thanks for this. I'm finding this choice very confusing: in the back of my mind, I'm dealing with a function of one variable, x, so it's unexpected that for every expression I write inside such a function, I need to check that what I think of as constants would still be well-behaved if they became variables. But perhaps there are good reasons for this.

I'm still running into trouble along these lines, though, I think when `arange` is used to build the polynomial's exponents. Here is a simple example that breaks:

```python
import jax
import jax.numpy as jnp

b = lambda z: jnp.sum(z**jnp.arange(0, 2))  # z**0 + z**1, exponents have an integer dtype
d = lambda z: 1. + z                         # the same polynomial, written out by hand
print(b(2.543), d(2.543))  # check they are the same and I'm not crazy
print(jax.grad(d)(1.))     # gives 1. as expected
print(jax.grad(b)(1.))     # gives 2.
```

Originally posted by @hongwanliu in #14397 (comment)
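For anyone hitting the same thing, here is a minimal sketch of two ways to build the same polynomial without an integer-dtype exponent array, assuming the degree is known when the function is traced. `poly_int_exponents` and `poly_horner` are hypothetical helper names. The first uses Python `int` exponents (as far as I know, `jnp.power` routes concrete Python-int exponents to `lax.integer_pow`); the second uses Horner's method and avoids `pow` entirely. Both should give a derivative of 1 at z = 1, matching `d`. This sidesteps the reported behavior rather than explaining it.

```python
import jax

# Hypothetical helper: c0 + c1*z + ... with all coefficients 1, using Python-int
# exponents so no integer-dtype exponent array is ever created.
def poly_int_exponents(z, degree=1):
    return sum(z ** n for n in range(degree + 1))

# Hypothetical helper: Horner's method, which uses only multiply and add.
# `coeffs` is ordered from the constant term upward: c0 + c1*z + c2*z**2 + ...
def poly_horner(z, coeffs=(1.0, 1.0)):
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * z + c
    return acc

d = lambda z: 1. + z  # reference polynomial from the report

for f in (poly_int_exponents, poly_horner):
    # value at 2.543 should match d, derivative at 1. should be 1.
    print(f(2.543), d(2.543), jax.grad(f)(1.))
```

Horner's method is the more general of the two: the coefficients can themselves be arrays or differentiable parameters, and the derivative with respect to `z` never touches a `pow` rule.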
