
map_coordinates mode='constant' not properly applied #5687

@agudallago


jax.scipy.ndimage.map_coordinates behaves differently from scipy.ndimage.map_coordinates when mode="constant" is used. In the example below, the coordinate 2.9 lies beyond the last sample of the array, just like 3, so both should produce cval (0). SciPy returns 0 for both, but JAX returns 1 for 2.9, which is wrong.
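To make the discrepancy concrete, here is a minimal pure-Python sketch of order-1 (linear) interpolation under two boundary conventions. It is an illustration, not the JAX or SciPy implementation: linear_interp_1d and its strict_domain flag are hypothetical, and the two conventions are inferred from the outputs reported below. With strict_domain=False, out-of-range neighbours read as cval but interpolation still runs, so 2.9 blends src[2] = 6 (weight about 0.1) with cval = 0, giving about 0.6, which rounds to 1. With strict_domain=True, any coordinate outside [0, n-1] returns cval directly.

```python
import math


def linear_interp_1d(src, x, cval=0.0, strict_domain=False):
    """Hypothetical order-1 (linear) interpolation of a 1-D sequence at x.

    strict_domain=False: neighbour indices outside [0, n-1] read as cval,
    but the interpolation still runs (the convention JAX's output suggests).
    strict_domain=True: any x outside [0, n-1] returns cval immediately
    (the convention SciPy's output suggests for mode="constant").
    """
    n = len(src)
    if strict_domain and not (0 <= x <= n - 1):
        return cval
    lo = math.floor(x)  # 2.9 -> 2, which is a valid index
    frac = x - lo       # weight given to the upper neighbour lo + 1

    def sample(i):
        # Out-of-range indices read as the constant cval.
        return src[i] if 0 <= i < n else cval

    return (1 - frac) * sample(lo) + frac * sample(lo + 1)


src = [5, 1, 6]
coords = [2, 3, 2.9]
# Round to the nearest integer because src holds integers (illustrative choice).
print([round(linear_interp_1d(src, x)) for x in coords])                      # [6, 0, 1]
print([round(linear_interp_1d(src, x, strict_domain=True)) for x in coords])  # [6, 0, 0]
```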

from scipy import ndimage
import jax
import jax.numpy as jnp
import jax.scipy.ndimage  # make sure the jax.scipy.ndimage submodule is loaded

src = jnp.array([5, 1, 6])  # Fails!
# src = jnp.array([5, 1, 5]) # Doesn't fail

# JAX floors 2.9 to 2, which is a valid index, so src[2] still contributes to the result.
coords = [[2, 3, 2.9]]


def scipy_map_coordinates():
  return ndimage.map_coordinates(src, coords, order=1, mode="constant")


def jax_map_coordinates():
  return jax.scipy.ndimage.map_coordinates(
      src, coords, order=1, mode="constant")


print(scipy_map_coordinates())
print(jax_map_coordinates())

assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())

Output:

[6 0 0]
[6 0 1]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-6-b13f0479754d> in <module>()
     20 print(jax_map_coordinates())
     21 
---> 22 assert jnp.array_equal(scipy_map_coordinates(), jax_map_coordinates())

AssertionError: 
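Until the behaviour changes, a possible user-side workaround is to mask the JAX result afterwards. This is a sketch under the assumption, consistent with the SciPy output above for 3 and 2.9, that SciPy's mode="constant" returns cval for every coordinate outside [0, n-1] along any axis. map_coordinates_strict_constant is a hypothetical helper, not part of JAX; it only calls jax.scipy.ndimage.map_coordinates and jnp.where.

```python
import jax.numpy as jnp
import jax.scipy.ndimage


def map_coordinates_strict_constant(src, coords, order=1, cval=0.0):
    """Hypothetical wrapper around jax.scipy.ndimage.map_coordinates.

    Runs JAX's mode="constant" interpolation, then forces cval wherever a
    coordinate lies outside [0, n-1] along any axis, mimicking the SciPy
    output shown in this issue. Not part of JAX.
    """
    src = jnp.asarray(src)
    coords = [jnp.asarray(c) for c in coords]
    out = jax.scipy.ndimage.map_coordinates(
        src, coords, order=order, mode="constant", cval=cval)
    # True only where every coordinate lies within the sampled domain.
    inside = jnp.ones(out.shape, dtype=bool)
    for c, n in zip(coords, src.shape):
        inside = inside & (c >= 0) & (c <= n - 1)
    return jnp.where(inside, out, jnp.asarray(cval, dtype=out.dtype))


print(map_coordinates_strict_constant(jnp.array([5, 1, 6]), [[2, 3, 2.9]]))  # [6 0 0]
```

Points whose coordinates are inside [0, n-1] keep JAX's interpolated values unchanged; only points beyond the first or last sample are overwritten with cval.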

CC: @claudiofantacci

Labels: bug, documentation, open
