Open
Labels: bug, documentation, open
Description
`jax.scipy.ndimage.map_coordinates` behaves differently from `scipy.ndimage.map_coordinates` when `mode="constant"`. For a length-3 input, valid coordinates run from 0 to 2. That means 2.9, like 3, lies outside the array and should produce `cval` (0), which is what SciPy returns. In the example below, JAX instead returns 1 for 2.9, which is wrong.
```python
from scipy import ndimage
import jax
import jax.numpy as jnp

src = jnp.array([5, 1, 6])  # Fails!
# src = jnp.array([5, 1, 5])  # Doesn't fail

# 2.9 is first floored (jnp.floor) to 2, which is a valid index,
# so src[2] still contributes to the interpolated value.
coords = [[2, 3, 2.9]]

def scipy_map_coordinates():
    return ndimage.map_coordinates(src, coords, order=1, mode="constant")

def jax_map_coordinates():
    return jax.scipy.ndimage.map_coordinates(
        src, coords, order=1, mode="constant")

print(scipy_map_coordinates())
print(jax_map_coordinates())
assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())
```
Yielding:

```
[6 0 0]
[6 0 1]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-b13f0479754d> in <module>()
     20 print(jax_map_coordinates())
     21 
---> 22 assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())

AssertionError:
```
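To make the arithmetic concrete, here is a minimal pure-Python sketch of 1-D linear interpolation under two readings of `mode="constant"`. The helper `pad_with_cval` treats each out-of-range neighbor as `cval` but still interpolates. That appears to be what JAX does here: for 2.9 the neighbors are index 2 (weight about 0.1, value 6) and index 3 (out of range, weight about 0.9, value `cval`), giving about 0.6, which rounds to 1 for integer input. The helper `no_interp_outside` returns `cval` for any coordinate outside `[0, n - 1]`, following SciPy's documented `mode="constant"`, where no interpolation is performed beyond the edges of the input. Both helper names are hypothetical, rounding to the nearest integer for integer inputs is an assumption, and neither function is JAX's or SciPy's actual implementation. The padded variant looks closer to what SciPy calls `mode="grid-constant"`.

```python
import math


def pad_with_cval(src, x, cval=0):
    """Hypothetical helper: linear interpolation where out-of-range
    neighbors read as cval (approximates the JAX behavior reported above)."""
    n = len(src)
    lower = math.floor(x)
    upper = lower + 1
    w_upper = x - lower        # weight of the right neighbor
    w_lower = 1.0 - w_upper    # weight of the left neighbor

    def read(i):
        # Indices outside [0, n) contribute cval instead of being skipped.
        return src[i] if 0 <= i < n else cval

    return read(lower) * w_lower + read(upper) * w_upper


def no_interp_outside(src, x, cval=0):
    """Hypothetical helper: returns cval for any x outside [0, n - 1]
    (approximates SciPy's mode="constant" for order=1)."""
    n = len(src)
    if x < 0 or x > n - 1:
        return cval
    lower = math.floor(x)
    frac = x - lower
    if frac == 0:
        # Exactly on a sample; avoids reading src[n] when x == n - 1.
        return src[lower]
    return src[lower] * (1.0 - frac) + src[lower + 1] * frac


src = [5, 1, 6]
coords = [2, 3, 2.9]

# Integer inputs are rounded to the nearest integer (an assumption).
scipy_like = [round(no_interp_outside(src, x)) for x in coords]
jax_like = [round(pad_with_cval(src, x)) for x in coords]
print(scipy_like)  # [6, 0, 0]
print(jax_like)    # [6, 0, 1]
```

In this sketch the two readings agree for every coordinate in `[0, n - 1]` and for every coordinate at or beyond -1 or n. They differ only on the open intervals (-1, 0) and (n - 1, n), where one neighbor is still in range.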
CC: @claudiofantacci