Skip to content

Commit cbe9976

Browse files
committed
[pytree] add another simplified pytree module torch.pytree
Differences between `torch.pytree` and `torch.utils.pytree`: 1. APIs in `torch.utils.pytree` have a `tree_` prefix: ```python leaves, treespec = torch.utils.pytree.tree_flatten(tree) new_tree = torch.utils.pytree.tree_map(func, tree) leaevs, treespec = torch.pytree.flatten(tree) new_tree = torch.pytree.map(func, tree) ``` 2. The argument order of `unflatten` is reversed for better `functools.partial` support: ```python tree = torch.utils.pytree.tree_unflatten(leaves, treespec) tree = torch.pytree.unflatten(treespec, leaves) unflatten_fn = functools.partial(torch.pytree.unflatten, treespec) tree1 = unflatten_fn(leaves1) tree2 = unflatten_fn(leaves2) ``` This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`. ghstack-source-id: 0bf5096 Pull Request resolved: #148180
1 parent 8614658 commit cbe9976

File tree

4 files changed

+120
-0
lines changed

4 files changed

+120
-0
lines changed

CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,5 @@ torch/backends/cudnn/ @eqy @syed-ahmed
188188
/torch/utils/_pytree.py @XuehaiPan
189189
/torch/utils/_cxx_pytree.py @XuehaiPan
190190
/torch/utils/pytree/ @XuehaiPan
191+
/torch/pytree.py @XuehaiPan
191192
/torch/_dynamo/polyfills/pytree.py @XuehaiPan

test/allowlist_for_publicAPI.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,21 @@
694694
"kineto_available",
695695
"record_function"
696696
],
697+
"torch.pytree": [
698+
"register_node",
699+
"all",
700+
"all_only",
701+
"any",
702+
"any_only",
703+
"flatten",
704+
"iter",
705+
"leaves",
706+
"map",
707+
"map_",
708+
"map_only",
709+
"map_only_",
710+
"structure"
711+
],
697712
"torch.quantization": [
698713
"ABC",
699714
"DeQuantStub",

torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2651,6 +2651,7 @@ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
26512651
_inductor as _inductor,
26522652
_subclasses as _subclasses,
26532653
onnx as onnx,
2654+
pytree as pytree,
26542655
)
26552656

26562657
else:
@@ -2660,6 +2661,7 @@ def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
26602661
"_export",
26612662
# ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
26622663
"onnx",
2664+
"pytree",
26632665
}
26642666

26652667
def __getattr__(name):

torch/pytree.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Owner(s): ["module: pytree"]
2+
3+
"""
4+
Contains utility functions for working with nested python data structures.
5+
6+
A *pytree* is Python nested data structure. It is a tree in the sense that
7+
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
8+
Python values. Furthermore, a pytree should not contain reference cycles.
9+
10+
pytrees are useful for working with nested collections of Tensors. For example,
11+
one can use `map` to map a function over all Tensors inside some nested
12+
collection of Tensors and `leaves` to get a flat list of all Tensors
13+
inside some nested collection. pytrees are helpful for implementing nested
14+
collection support for PyTorch APIs.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
from typing import Any as _Any, TYPE_CHECKING as _TYPE_CHECKING
20+
21+
import torch
22+
from torch.utils.pytree import (
23+
register_pytree_node as register_node,
24+
tree_all as all,
25+
tree_all_only as all_only,
26+
tree_any as any,
27+
tree_any_only as any_only,
28+
tree_flatten as flatten,
29+
tree_iter as iter,
30+
tree_leaves as leaves,
31+
tree_map as map,
32+
tree_map_ as map_,
33+
tree_map_only as map_only,
34+
tree_map_only_ as map_only_,
35+
tree_structure as structure,
36+
)
37+
38+
39+
if _TYPE_CHECKING:
40+
from collections.abc import Iterable
41+
42+
from torch.utils._cxx_pytree import PyTree as PyTree, PyTreeSpec as PyTreeSpec
43+
44+
45+
__all__ = [
46+
"PyTreeSpec",
47+
"register_node",
48+
"flatten",
49+
"unflatten",
50+
"iter",
51+
"leaves",
52+
"structure",
53+
"map",
54+
"map_",
55+
"map_only",
56+
"map_only_",
57+
"all",
58+
"any",
59+
"all_only",
60+
"any_only",
61+
]
62+
63+
64+
def unflatten(treespec: PyTreeSpec, leaves: Iterable[_Any]) -> PyTree:
65+
"""Reconstruct a pytree from the treespec and the leaves.
66+
67+
The inverse of :func:`flatten`.
68+
69+
>>> tree = {"b": (2, [3, 4]), "a": 1, "c": None, "d": 5}
70+
>>> leaves, treespec = torch.pytree.flatten(tree)
71+
>>> tree == torch.pytree.unflatten(treespec, leaves)
72+
True
73+
74+
.. note::
75+
76+
This function has a different signature than :func:`torch.utils.pytree.tree_unflatten`.
77+
The ``treespec`` argument comes first to have a better :class:`functools.partial` support:
78+
79+
.. code-block:: python
80+
81+
import functools
82+
83+
unflatten_fn = functools.partial(unflatten, treespec)
84+
tree1 = unflatten_fn(leaves1)
85+
tree2 = unflatten_fn(leaves2)
86+
87+
Args:
88+
treespec (PyTreeSpec): The treespec to reconstruct.
89+
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
90+
number of leaves of the treespec.
91+
92+
Returns:
93+
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
94+
``treespec``.
95+
"""
96+
return torch.utils.pytree.tree_unflatten(leaves, treespec)
97+
98+
99+
def __getattr__(name: str) -> _Any:
100+
if name in ("PyTreeSpec", "TreeSpec"):
101+
return torch.utils.pytree.PyTreeSpec
102+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

0 commit comments

Comments
 (0)