@@ -1550,6 +1550,60 @@ std::optional<std::pair<nb::object, nb::object>> PyTreeDef::GetNodeData()
1550
1550
}
1551
1551
}
1552
1552
1553
+ nb_class_ptr<PyTreeDef> PyTreeDef::FromNodeDataAndChildren (
1554
+ nb_class_ptr<PyTreeRegistry> registry,
1555
+ std::optional<std::pair<nb::object, nb::object>> node_data,
1556
+ nb::iterable children) {
1557
+ nb_class_ptr<PyTreeDef> result =
1558
+ make_nb_class<PyTreeDef>(std::move (registry));
1559
+ int num_leaves = 0 ;
1560
+ int arity = 0 ;
1561
+ for (nb::handle pchild : children) {
1562
+ const PyTreeDef& child = nb::cast<const PyTreeDef&>(pchild);
1563
+ absl::c_copy (child.traversal_ , std::back_inserter (result->traversal_ ));
1564
+ num_leaves += child.num_leaves ();
1565
+ ++arity;
1566
+ }
1567
+ result->traversal_ .emplace_back ();
1568
+ auto & node = result->traversal_ .back ();
1569
+ node.arity = arity;
1570
+ node.custom = nullptr ;
1571
+ node.num_leaves = num_leaves;
1572
+ node.num_nodes = result->traversal_ .size ();
1573
+ if (node_data == std::nullopt ) {
1574
+ node.kind = PyTreeKind::kLeaf ;
1575
+ ++node.num_leaves ;
1576
+ return result;
1577
+ }
1578
+ int is_nt = PyObject_IsSubclass (node_data->first .ptr (),
1579
+ reinterpret_cast <PyObject*>(&PyTuple_Type));
1580
+ if (is_nt == -1 ) {
1581
+ throw nb::python_error ();
1582
+ }
1583
+ if (is_nt != 0 && nb::hasattr (node_data->first , " _fields" )) {
1584
+ node.kind = PyTreeKind::kNamedTuple ;
1585
+ node.node_data = node_data->first ;
1586
+ return result;
1587
+ }
1588
+ auto * registration = result->registry ()->Lookup (node_data->first );
1589
+ if (registration == nullptr ) {
1590
+ throw std::logic_error (absl::StrFormat (
1591
+ " Could not find type: %s." ,
1592
+ nb::cast<absl::string_view>(nb::repr (node_data->first ))));
1593
+ }
1594
+ node.kind = registration->kind ;
1595
+ if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass ) {
1596
+ node.custom = registration;
1597
+ node.node_data = node_data->second ;
1598
+ } else if (node.kind == PyTreeKind::kNamedTuple ) {
1599
+ node.node_data = node_data->first ;
1600
+ } else if (node.kind == PyTreeKind::kDict ) {
1601
+ node.sorted_dict_keys =
1602
+ nb::cast<std::vector<nb::object>>(node_data->second );
1603
+ }
1604
+ return result;
1605
+ }
1606
+
1553
1607
int PyTreeDef::Node::tp_traverse (visitproc visit, void * arg) const {
1554
1608
Py_VISIT (node_data.ptr ());
1555
1609
for (const auto & key : sorted_dict_keys) {
@@ -1789,6 +1843,21 @@ void BuildPytreeSubmodule(nb::module_& m) {
1789
1843
treedef.def (" node_data" , &PyTreeDef::GetNodeData,
1790
1844
" Returns None if a leaf-pytree, else (type, node_data)" ,
1791
1845
nb::sig (" def node_data(self) -> tuple[type, Any] | None" ));
1846
+ treedef.def_static (
1847
+ " from_node_data_and_children" ,
1848
+ &PyTreeDef::FromNodeDataAndChildren, nb::arg (" registry" ),
1849
+ nb::arg (" node_data" ).none (), nb::arg (" children" ),
1850
+ " Reconstructs a pytree from `node_data()` and `children()`." ,
1851
+ nb::sig (
1852
+ // clang-format off
1853
+ " def from_node_data_and_children("
1854
+ " self, "
1855
+ " registry: PyTreeRegistry, "
1856
+ " node_data: tuple[type, Any] | None, "
1857
+ " children: typing.Iterable[PyTreeDef]"
1858
+ " ) -> PyTreeDef"
1859
+ // clang-format on
1860
+ ));
1792
1861
treedef.def (" __getstate__" , &PyTreeDef::ToPickle);
1793
1862
treedef.def (" __setstate__" , [](PyTreeDef& t, nb::object o) {
1794
1863
nb::tuple pickle = nb::cast<nb::tuple>(o);
0 commit comments