
Commit fc24fb9

Jake VanderPlas authored and Google-ML-Automation committed
lax.acosh: express gradient using x^2-1 directly.

The (x+1)(x-1) form can be more precise as x->0, but since acosh is only defined for x>=1, we can use the direct form. This saves one op and leads to cleaner second-order autodiff.

PiperOrigin-RevId: 816343493
1 parent 6613ad4 commit fc24fb9

File tree

1 file changed: +3 −1


jax/_src/lax/lax.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -4377,7 +4377,9 @@ def atan_impl(x):
 
 acosh_p = standard_unop(_float | _complex, 'acosh')
 ad.defjvp(acosh_p,
-          lambda g, x: mul(g, rsqrt(mul(sub(x, _one(x)), add(x, _one(x))))))
+          # We use x^2-1 rather than (x+1)(x-1). The latter is more accurate
+          # for x near zero, but the function domain is x>=1.
+          lambda g, x: mul(g, rsqrt(sub(square(x), _one(x)))))
 mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
 
 atanh_p = standard_unop(_float | _complex, 'atanh')
```
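Since d/dx acosh(x) = 1/sqrt(x^2 - 1) and the domain is x >= 1, the factored and direct forms compute the same value on every valid input. A minimal sketch in plain Python (not the JAX source; function names here are illustrative only) comparing the old and new expressions:

```python
import math

def grad_factored(x):
    # Old expression: 1 / sqrt((x - 1) * (x + 1))
    return 1.0 / math.sqrt((x - 1.0) * (x + 1.0))

def grad_direct(x):
    # New expression: 1 / sqrt(x^2 - 1), one multiply instead of
    # a subtract, an add, and a multiply.
    return 1.0 / math.sqrt(x * x - 1.0)

# The two forms agree across acosh's domain x >= 1 (away from the
# singularity at x == 1, where both blow up identically).
for x in (1.5, 2.0, 10.0, 1e6):
    assert math.isclose(grad_factored(x), grad_direct(x), rel_tol=1e-12)
```

The factored form only helps near x = 0, where x^2 - 1 would lose precision to cancellation; that region is outside acosh's domain, so the direct form costs nothing in accuracy.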
