Skip to content

Commit b4f2676

Browse files
superbobryGoogle-ML-Automation
authored and committed
[jaxlib] Revived PyTreeDef::FromNodeDataAndChildren
This method is useful for serialization as discussed in #32186. Reverts 0e7c96a PiperOrigin-RevId: 816358350
1 parent fc24fb9 commit b4f2676

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

jaxlib/_jax/pytree.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ class PyTreeDef:
121121
def from_iterable_tree(self, __xs: Any): ...
122122
def node_data(self) -> Tuple[type, Any] | None: ...
123123
def children(self) -> list[PyTreeDef]: ...
124+
# Reconstructs a PyTreeDef from the values returned by `node_data()` and
# `children()`; `node_data` is None for a leaf.
@staticmethod
125+
def from_node_data_and_children(
126+
registry: PyTreeRegistry,
127+
node_data: Tuple[type, Any] | None,
128+
children: Iterable[PyTreeDef],
129+
) -> PyTreeDef: ...
124130

125131
num_leaves: int
126132
num_nodes: int

jaxlib/pytree.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,60 @@ std::optional<std::pair<nb::object, nb::object>> PyTreeDef::GetNodeData()
15501550
}
15511551
}
15521552

1553+
// Reconstructs the PyTreeDef of a single node from the values produced by
// `GetNodeData()` and the treedefs of the node's children.
//
// Args:
//   registry: registry used to look up the node type; it is also handed to
//     the newly created PyTreeDef.
//   node_data: `std::nullopt` for a leaf, otherwise a (type, data) pair.
//   children: iterable of PyTreeDef objects, one per child, in child order.
//
// Returns:
//   A new PyTreeDef whose traversal is the concatenation of the children's
//   traversals followed by one node describing the parent.
//
// Throws:
//   nb::python_error if the subclass check against `tuple` raises;
//   std::logic_error if the type is neither a namedtuple nor registered;
//   std::invalid_argument if a leaf is given children, or if a dict node's
//     key count does not match the number of children.
nb_class_ptr<PyTreeDef> PyTreeDef::FromNodeDataAndChildren(
    nb_class_ptr<PyTreeRegistry> registry,
    std::optional<std::pair<nb::object, nb::object>> node_data,
    nb::iterable children) {
  nb_class_ptr<PyTreeDef> result =
      make_nb_class<PyTreeDef>(std::move(registry));

  // Append every child's traversal first; the parent node goes last.
  int num_leaves = 0;
  int arity = 0;
  for (nb::handle pchild : children) {
    const PyTreeDef& child = nb::cast<const PyTreeDef&>(pchild);
    absl::c_copy(child.traversal_, std::back_inserter(result->traversal_));
    num_leaves += child.num_leaves();
    ++arity;
  }

  result->traversal_.emplace_back();
  auto& node = result->traversal_.back();
  node.arity = arity;
  node.custom = nullptr;
  node.num_leaves = num_leaves;
  // The parent has already been appended, so the size includes it.
  node.num_nodes = result->traversal_.size();

  if (!node_data.has_value()) {
    // A leaf has no children; accepting some would yield a leaf node with a
    // non-zero arity and stray child nodes in the traversal.
    if (arity != 0) {
      throw std::invalid_argument(absl::StrFormat(
          "A leaf node cannot have children, but %d were provided.", arity));
    }
    node.kind = PyTreeKind::kLeaf;
    ++node.num_leaves;
    return result;
  }

  // Namedtuples are recognised structurally (tuple subclass with `_fields`)
  // before consulting the registry.
  int is_nt = PyObject_IsSubclass(node_data->first.ptr(),
                                  reinterpret_cast<PyObject*>(&PyTuple_Type));
  if (is_nt == -1) {
    throw nb::python_error();
  }
  if (is_nt != 0 && nb::hasattr(node_data->first, "_fields")) {
    node.kind = PyTreeKind::kNamedTuple;
    node.node_data = node_data->first;
    return result;
  }

  auto* registration = result->registry()->Lookup(node_data->first);
  if (registration == nullptr) {
    throw std::logic_error(absl::StrFormat(
        "Could not find type: %s.",
        nb::cast<absl::string_view>(nb::repr(node_data->first))));
  }
  node.kind = registration->kind;
  if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) {
    node.custom = registration;
    node.node_data = node_data->second;
  } else if (node.kind == PyTreeKind::kNamedTuple) {
    node.node_data = node_data->first;
  } else if (node.kind == PyTreeKind::kDict) {
    auto keys = nb::cast<std::vector<nb::object>>(node_data->second);
    // Keep the node internally consistent: exactly one key per child.
    if (keys.size() != static_cast<size_t>(arity)) {
      throw std::invalid_argument(absl::StrFormat(
          "Dict node has %d keys but %d children.", keys.size(), arity));
    }
    node.sorted_dict_keys = std::move(keys);
  }
  return result;
}
1606+
15531607
int PyTreeDef::Node::tp_traverse(visitproc visit, void* arg) const {
15541608
Py_VISIT(node_data.ptr());
15551609
for (const auto& key : sorted_dict_keys) {
@@ -1789,6 +1843,21 @@ void BuildPytreeSubmodule(nb::module_& m) {
17891843
treedef.def("node_data", &PyTreeDef::GetNodeData,
17901844
"Returns None if a leaf-pytree, else (type, node_data)",
17911845
nb::sig("def node_data(self) -> tuple[type, Any] | None"));
1846+
// Exposes PyTreeDef::FromNodeDataAndChildren as a static method. The
// explicit signature must not list `self`: the function is registered with
// `def_static` and takes `registry` as its first argument.
treedef.def_static(
    "from_node_data_and_children",
    &PyTreeDef::FromNodeDataAndChildren, nb::arg("registry"),
    nb::arg("node_data").none(), nb::arg("children"),
    "Reconstructs a pytree from `node_data()` and `children()`.",
    nb::sig(
      // clang-format off
      "def from_node_data_and_children("
      "registry: PyTreeRegistry, "
      "node_data: tuple[type, Any] | None, "
      "children: typing.Iterable[PyTreeDef]"
      ") -> PyTreeDef"
      // clang-format on
    ));
17921861
treedef.def("__getstate__", &PyTreeDef::ToPickle);
17931862
treedef.def("__setstate__", [](PyTreeDef& t, nb::object o) {
17941863
nb::tuple pickle = nb::cast<nb::tuple>(o);

jaxlib/pytree.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,11 @@ class PyTreeDef {
311311
std::optional<std::pair<nanobind::object, nanobind::object>> GetNodeData()
312312
const;
313313

314+
// Builds a PyTreeDef for a single node from its (type, node data) pair and
// the treedefs of its children. A std::nullopt `node_data` denotes a leaf.
static nb_class_ptr<PyTreeDef> FromNodeDataAndChildren(
315+
nb_class_ptr<PyTreeRegistry> registry,
316+
std::optional<std::pair<nanobind::object, nanobind::object>> node_data,
317+
nanobind::iterable children);
318+
314319
static PyType_Slot slots_[];
315320

316321
private:

jaxlib/pytree_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self, field0, field1):
3737
def to_iterable(self):
3838
return [self.field0, self.field1], (None,)
3939

40+
4041
def from_iterable(state, values):
4142
del state
4243
return ExampleType2(field0=values[0], field1=values[1])
@@ -79,11 +80,38 @@ def testRegisteredType(self):
7980
with self.assertRaises(ValueError):
8081
self.roundtrip_proto({"a": ExampleType2(field0=o, field1=o)})
8182

83+
def roundtrip_node_data(self, example):
  """Checks that a treedef rebuilt from its node data and children is equal
  to the treedef obtained by flattening `example`."""
  _, treedef = registry.flatten(example)
  node_data = treedef.node_data()
  children = treedef.children()
  rebuilt = pytree.PyTreeDef.from_node_data_and_children(
      registry, node_data, children
  )
  self.assertEqual(rebuilt, treedef)
89+
90+
def testRoundtripNodeData(self):
  """Round-trips the treedefs of several container kinds, None, a bare leaf
  and the two example types through node data and children."""
  leaf = object()
  examples = (
      [leaf, leaf, leaf],
      (leaf, leaf, leaf),
      {"a": leaf, "b": leaf},
      {22: leaf, 88: leaf},
      None,
      leaf,
      ExampleType(field0=leaf, field1=leaf),
      ExampleType2(field0=leaf, field1=leaf),
  )
  for example in examples:
    self.roundtrip_node_data(example)
100+
82101
def testCompose(self):
83102
x = registry.flatten(0)[1]
84103
y = registry.flatten((0, 0))[1]
85104
self.assertEqual((x.compose(y)).num_leaves, 2)
86105

106+
def testDataclassMakeFromNodeData(self):
  """A treedef rebuilt from node data must unflatten to an equal `Custom`
  value and print the same as the original treedef."""
  value = Custom(1, "a")
  leaves, treedef = registry.flatten(value)
  rebuilt = pytree.PyTreeDef.from_node_data_and_children(
      registry, treedef.node_data(), treedef.children()
  )
  self.assertEqual(rebuilt.unflatten(leaves), value)
  self.assertEqual(str(rebuilt), str(treedef))
114+
87115
def testTpTraverse(self):
88116
self.assertContainsSubset(
89117
[

0 commit comments

Comments
 (0)