Hi @mattjj, thanks for this. I find this choice very confusing. In the back of my mind I'm dealing with a function of one variable, x, so it's unexpected that for every expression I write inside such a function, I need to check that what I think of as constants would still be well-behaved if they became variables. But perhaps there are good reasons for this.
I'm still running into trouble along these lines though, I think when `arange` is used to build the polynomial. Here is a simple example that breaks:
import jax
import jax.numpy as jnp
b = lambda z: jnp.sum(z**jnp.arange(0, 2))
d = lambda z: 1. + z
print(b(2.543), d(2.543)) # check they are the same and I'm not crazy
print(jax.grad(d)(1.)) # gives 1. as expected
print(jax.grad(b)(1.)) # gives 2., but 1. is expected
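For comparison, here is a minimal sketch that evaluates the same polynomial, 1 + z, with Horner's rule, so the gradient only passes through multiplies and adds and never through `z**jnp.arange(...)`. The helper name `poly_horner` and the coefficient ordering (highest degree first) are assumptions made for illustration; they are not part of JAX, and this is a possible workaround rather than an explanation of the behavior above.

```python
import jax

def poly_horner(coeffs, z):
    """Evaluate a polynomial at z; coeffs are ordered highest degree first."""
    acc = 0.0
    for c in coeffs:
        # Horner step: fold in one coefficient using only a multiply and an add.
        acc = acc * z + c
    return acc

# Coefficients of 1 + z (hypothetical example matching b and d above).
coeffs = [1.0, 1.0]
p = lambda z: poly_horner(coeffs, z)

value = p(2.543)          # same polynomial as b and d, evaluated at 2.543
slope = jax.grad(p)(1.0)  # derivative of 1 + z with respect to z
print(value, slope)
```

Because the loop runs over a plain Python list, it is unrolled when JAX traces the function, and the exponent never appears as an array argument to a power operation.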
Originally posted by @hongwanliu in #14397 (comment)