
Commit d0ecd2a

superbobry authored and Google-ML-Automation committed
[pjrt] Do not ignore shape layout in CommonPjRtClient::CreateViewOfDeviceBuffer
This allows JAX to support importing DLPack tensors with non-default layouts.

PiperOrigin-RevId: 816385564
1 parent b4f2676 commit d0ecd2a
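
Below is a minimal sketch of the behavior this commit enables, assuming jaxlib 0.8.0 or newer (the version the new tests gate on); the array contents and dtype are illustrative only.

```python
import numpy as np
import jax
import jax.dlpack

# A transposed view is stored column-major, i.e. with a non-default layout.
x = np.arange(4, dtype=np.float32).reshape(2, 2).T

# Before this change, jax.dlpack.from_dlpack rejected such inputs with an
# Unimplemented "non-default layout" error; now the view imports directly.
y = jax.dlpack.from_dlpack(x)
np.testing.assert_array_equal(np.asarray(y), x)
```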

4 files changed: +34 additions, −38 deletions


CHANGELOG.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -64,7 +64,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 * Changes
   * `jax.grad` and `jax.vjp` will now round always primals to float32 if float64
     mode is not enabled.
-
+  * {func}`jax.dlpack.from_dlpack` now accepts arrays with non-default layouts,
+    for example, transposed.
 
 ## JAX 0.7.2 (September 16, 2025)
 
```

jaxlib/dlpack.cc

Lines changed: 0 additions & 16 deletions
```diff
@@ -347,22 +347,6 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
   }
   xla::Shape shape = xla::ShapeUtil::MakeShapeWithDenseLayout(
       element_type, dimensions, minor_to_major);
-  // Raise an error if the resulting xla::PjRtBuffer would have a non-default
-  // layout.
-  // TODO(skyewm): we do this because JAX doesn't currently have good support
-  // for non-default layouts, and will return wrong results if a non-default
-  // layout is passed to a computation expecting default layouts. Remove this
-  // special case when non-default layouts are better supported by JAX.
-  TF_ASSIGN_OR_RETURN(xla::Layout default_layout,
-                      device->pjrt_device()->client()->GetDefaultLayout(
-                          element_type, dimensions));
-  if (shape.layout() != default_layout) {
-    return xla::Unimplemented(
-        "from_dlpack got array with non-default layout with minor-to-major "
-        "dimensions (%s), expected (%s)",
-        absl::StrJoin(shape.layout().minor_to_major(), ","),
-        absl::StrJoin(default_layout.minor_to_major(), ","));
-  }
 
   TF_ASSIGN_OR_RETURN(auto pjrt_buffer,
                       MakePjrtBuffer(*device->pjrt_device(), dlmt, shape,
```
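
As a small NumPy illustration of what "minor-to-major" means in the deleted check: for a row-major matrix the last dimension varies fastest in memory, which is the default minor_to_major of (1, 0); a transposed view keeps the same buffer with swapped strides, which corresponds to minor_to_major (0, 1). The strides shown assume int32 elements.

```python
import numpy as np

a = np.arange(4, dtype=np.int32).reshape(2, 2)
print(a.strides)  # (8, 4): dimension 1 is minor, so minor_to_major (1, 0), the default

b = a.T           # same buffer, swapped strides
print(b.strides)  # (4, 8): dimension 0 is minor, so minor_to_major (0, 1), non-default
```

This (0, 1) layout is exactly what the removed `xla::Unimplemented` error reported, and what `CreateViewOfDeviceBuffer` now respects instead of ignoring.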

tests/array_interoperability_test.py

Lines changed: 21 additions & 18 deletions
```diff
@@ -15,6 +15,7 @@
 import unittest
 
 from absl.testing import absltest
+import numpy as np
 
 import jax
 import jax.dlpack
@@ -23,8 +24,7 @@
 from jax._src import config
 from jax._src import dlpack as dlpack_src
 from jax._src import test_util as jtu
-
-import numpy as np
+from jax._src.lib import version as jaxlib_version
 
 config.parse_flags_with_absl()
 
@@ -173,6 +173,15 @@ def testTensorFlowToJaxInt64(self):
     dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
     self.assertEqual(x.dtype, dtype_expected)
 
+  @unittest.skipIf(not tf, "Test requires TensorFlow")
+  def testTensorFlowToJaxNondefaultLayout(self):
+    if jaxlib_version < (0, 8, 0):
+      self.skipTest(
+          "Non-default layout support requires jaxlib 0.8.0 or newer"
+      )
+    x = tf.transpose(np.arange(4).reshape(2, 2))
+    self.assertAllClose(x.numpy(), jax.dlpack.from_dlpack(x))
+
   @jtu.sample_product(shape=all_shapes, dtype=numpy_dtypes, copy=[False, True])
   def testNumpyToJax(self, shape, dtype, copy):
     rng = jtu.rand_default(self.rng())
@@ -186,28 +195,22 @@ def testNumpyToJax(self, shape, dtype, copy):
     else:
       self.assertAllClose(x_np, _from_dlpack())
 
-  @jtu.sample_product(
-      shape=all_shapes,
-      dtype=numpy_dtypes,
-  )
-  @jtu.run_on_devices("cpu")  # NumPy only accepts cpu DLPacks
+  def testNumpyToJaxNondefaultLayout(self):
+    if jaxlib_version < (0, 8, 0):
+      self.skipTest(
+          "Non-default layout support requires jaxlib 0.8.0 or newer"
+      )
+    x = np.arange(4).reshape(2, 2).T
+    self.assertAllClose(x, jax.dlpack.from_dlpack(x))
+
+  @jtu.sample_product(shape=all_shapes, dtype=numpy_dtypes)
+  @jtu.run_on_devices("cpu")  # NumPy only accepts cpu DLPacks
   def testJaxToNumpy(self, shape, dtype):
     rng = jtu.rand_default(self.rng())
     x_jax = jnp.array(rng(shape, dtype))
     x_np = np.from_dlpack(x_jax)
     self.assertAllClose(x_np, x_jax)
 
-  def testNondefaultLayout(self):
-    # Generate numpy array with nonstandard layout
-    a = np.arange(4).reshape(2, 2)
-    b = a.T
-    with self.assertRaisesRegex(
-        RuntimeError,
-        "from_dlpack got array with non-default layout with minor-to-major "
-        r"dimensions \(0,1\), expected \(1,0\)",
-    ):
-      jax.dlpack.from_dlpack(b)
-
 
 class CudaArrayInterfaceTest(jtu.JaxTestCase):
```
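
For callers that previously worked around the restriction exercised by the removed `testNondefaultLayout`, the usual fix of forcing a default-layout copy before import is no longer needed. A hedged sketch, assuming jaxlib 0.8.0 or newer; the array and variable names are illustrative:

```python
import numpy as np
import jax
import jax.dlpack

b = np.arange(4, dtype=np.float32).reshape(2, 2).T  # Fortran-ordered view

# Old workaround: copy into a C-contiguous (default-layout) buffer first.
y_copy = jax.dlpack.from_dlpack(np.ascontiguousarray(b))

# With jaxlib >= 0.8.0 the transposed view can be imported directly.
y_view = jax.dlpack.from_dlpack(b)
np.testing.assert_array_equal(np.asarray(y_view), np.asarray(y_copy))
```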

tests/pytorch_interoperability_test.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -17,11 +17,12 @@
 from absl.testing import absltest
 
 import jax
-import jax.dlpack
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
+from jax._src.lib import version as jaxlib_version
 from jax._src.lib import xla_client
+import jax.dlpack
 import jax.numpy as jnp
 
 config.parse_flags_with_absl()
@@ -114,6 +115,15 @@ def testTorchToJaxInt64(self):
     dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32
     self.assertEqual(x.dtype, dtype_expected)
 
+  def testTorchToJaxNondefaultLayout(self):
+    if jaxlib_version < (0, 8, 0):
+      self.skipTest(
+          "Non-default layout support requires jaxlib 0.8.0 or newer"
+      )
+    x = torch.arange(4).reshape(2, 2).T
+    x = x.cuda() if jtu.test_device_matches(["gpu"]) else x
+    self.assertAllClose(x.cpu().numpy(), jax.dlpack.from_dlpack(x))
+
   @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes)
   def testTorchToJax(self, shape, dtype):
     if not config.enable_x64.value and dtype in [
@@ -130,7 +140,6 @@ def testTorchToJax(self, shape, dtype):
     else:
       x = torch.tensor(x_np)
     x = x.cuda() if jtu.test_device_matches(["gpu"]) else x
-    x = x.contiguous()
     y = jax.dlpack.from_dlpack(x)
     self.assertAllClose(x_np, y)
 
@@ -154,7 +163,6 @@ def testTorchToJaxArray(self, shape, dtype):
     else:
       x = torch.tensor(x_np)
     x = x.cuda() if jtu.test_device_matches(["gpu"]) else x
-    x = x.contiguous()
     y = jax.dlpack.from_dlpack(x)
     self.assertAllClose(x_np, y)
```
160168
