_pytree.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional, TypeVar, overload, Union
  2. import functools
  3. from collections import namedtuple, OrderedDict
  4. from dataclasses import dataclass
  5. T = TypeVar('T')
  6. S = TypeVar('S')
  7. U = TypeVar('U')
  8. R = TypeVar('R')
  9. """
  10. Contains utility functions for working with nested python data structures.
  11. A *pytree* is Python nested data structure. It is a tree in the sense that
  12. nodes are Python collections (e.g., list, tuple, dict) and the leaves are
  13. Python values. Furthermore, a pytree should not contain reference cycles.
  14. pytrees are useful for working with nested collections of Tensors. For example,
  15. one can use `tree_map` to map a function over all Tensors inside some nested
  16. collection of Tensors and `tree_unflatten` to get a flat list of all Tensors
  17. inside some nested collection. pytrees are helpful for implementing nested
  18. collection support for PyTorch APIs.
  19. This pytree implementation is not very performant due to Python overhead
  20. To improve the performance we can move parts of the implementation to C++.
  21. """
  22. # A NodeDef holds two callables:
  23. # - flatten_fn should take the collection and return a flat list of values.
  24. # It can also return some context that is used in reconstructing the
  25. # collection.
  26. # - unflatten_fn should take a flat list of values and some context
  27. # (returned by flatten_fn). It returns the collection by reconstructing
  28. # it from the list and the context.
  29. Context = Any
  30. PyTree = Any
  31. FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
  32. UnflattenFunc = Callable[[List, Context], PyTree]
  33. class NodeDef(NamedTuple):
  34. flatten_fn: FlattenFunc
  35. unflatten_fn: UnflattenFunc
  36. SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
  37. def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
  38. SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn)
  39. def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
  40. return list(d.values()), list(d.keys())
  41. def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
  42. return {key: value for key, value in zip(context, values)}
  43. def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
  44. return d, None
  45. def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
  46. return list(values)
  47. def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
  48. return list(d), None
  49. def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
  50. return tuple(values)
  51. def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
  52. return list(d), type(d)
  53. def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
  54. return cast(NamedTuple, context(*values))
  55. def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]:
  56. return list(d.values()), list(d.keys())
  57. def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]':
  58. return OrderedDict((key, value) for key, value in zip(context, values))
  59. _register_pytree_node(dict, _dict_flatten, _dict_unflatten)
  60. _register_pytree_node(list, _list_flatten, _list_unflatten)
  61. _register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
  62. _register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten)
  63. _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
  64. # h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
  65. def _is_namedtuple_instance(pytree: Any) -> bool:
  66. typ = type(pytree)
  67. bases = typ.__bases__
  68. if len(bases) != 1 or bases[0] != tuple:
  69. return False
  70. fields = getattr(typ, '_fields', None)
  71. if not isinstance(fields, tuple):
  72. return False
  73. return all(type(entry) == str for entry in fields)
  74. def _get_node_type(pytree: Any) -> Any:
  75. if _is_namedtuple_instance(pytree):
  76. return namedtuple
  77. return type(pytree)
  78. # A leaf is defined as anything that is not a Node.
  79. def _is_leaf(pytree: PyTree) -> bool:
  80. return _get_node_type(pytree) not in SUPPORTED_NODES.keys()
  81. # A TreeSpec represents the structure of a pytree. It holds:
  82. # "type": the type of root Node of the pytree
  83. # context: some context that is useful in unflattening the pytree
  84. # children_specs: specs for each child of the root Node
  85. # num_leaves: the number of leaves
  86. @dataclass
  87. class TreeSpec:
  88. type: Any
  89. context: Context
  90. children_specs: List['TreeSpec']
  91. def __post_init__(self) -> None:
  92. self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs])
  93. def __repr__(self, indent: int = 0) -> str:
  94. repr_prefix: str = f'TreeSpec({self.type.__name__}, {self.context}, ['
  95. children_specs_str: str = ''
  96. if len(self.children_specs):
  97. indent += len(repr_prefix)
  98. children_specs_str += self.children_specs[0].__repr__(indent)
  99. children_specs_str += ',' if len(self.children_specs) > 1 else ''
  100. children_specs_str += ','.join(['\n' + ' ' * indent + child.__repr__(indent) for child in self.children_specs[1:]])
  101. repr_suffix: str = f'{children_specs_str}])'
  102. return repr_prefix + repr_suffix
  103. class LeafSpec(TreeSpec):
  104. def __init__(self) -> None:
  105. super().__init__(None, None, [])
  106. self.num_leaves = 1
  107. def __repr__(self, indent: int = 0) -> str:
  108. return '*'
  109. def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
  110. """Flattens a pytree into a list of values and a TreeSpec that can be used
  111. to reconstruct the pytree.
  112. """
  113. if _is_leaf(pytree):
  114. return [pytree], LeafSpec()
  115. node_type = _get_node_type(pytree)
  116. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  117. child_pytrees, context = flatten_fn(pytree)
  118. # Recursively flatten the children
  119. result : List[Any] = []
  120. children_specs : List['TreeSpec'] = []
  121. for child in child_pytrees:
  122. flat, child_spec = tree_flatten(child)
  123. result += flat
  124. children_specs.append(child_spec)
  125. return result, TreeSpec(node_type, context, children_specs)
  126. def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
  127. """Given a list of values and a TreeSpec, builds a pytree.
  128. This is the inverse operation of `tree_flatten`.
  129. """
  130. if not isinstance(spec, TreeSpec):
  131. raise ValueError(
  132. f'tree_unflatten(values, spec): Expected `spec` to be instance of '
  133. f'TreeSpec but got item of type {type(spec)}.')
  134. if len(values) != spec.num_leaves:
  135. raise ValueError(
  136. f'tree_unflatten(values, spec): `values` has length {len(values)} '
  137. f'but the spec refers to a pytree that holds {spec.num_leaves} '
  138. f'items ({spec}).')
  139. if isinstance(spec, LeafSpec):
  140. return values[0]
  141. unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn
  142. # Recursively unflatten the children
  143. start = 0
  144. end = 0
  145. child_pytrees = []
  146. for child_spec in spec.children_specs:
  147. end += child_spec.num_leaves
  148. child_pytrees.append(tree_unflatten(values[start:end], child_spec))
  149. start = end
  150. return unflatten_fn(child_pytrees, spec.context)
  151. def tree_map(fn: Any, pytree: PyTree) -> PyTree:
  152. flat_args, spec = tree_flatten(pytree)
  153. return tree_unflatten([fn(i) for i in flat_args], spec)
  154. Type2 = Tuple[Type[T], Type[S]]
  155. Type3 = Tuple[Type[T], Type[S], Type[U]]
  156. TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
  157. Fn3 = Callable[[Union[T, S, U]], R]
  158. Fn2 = Callable[[Union[T, S]], R]
  159. Fn = Callable[[T], R]
  160. FnAny = Callable[[Any], R]
  161. MapOnlyFn = Callable[[T], Callable[[Any], Any]]
  162. # These specializations help with type inference on the lambda passed to this
  163. # function
  164. @overload
  165. def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
  166. ...
  167. @overload
  168. def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
  169. ...
  170. # This specialization is needed for the implementations below that call
  171. @overload
  172. def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
  173. ...
  174. def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
  175. """
  176. Suppose you are writing a tree_map over tensors, leaving everything
  177. else unchanged. Ordinarily you would have to write:
  178. def go(t):
  179. if isinstance(t, Tensor):
  180. return ...
  181. else:
  182. return t
  183. With this function, you only need to write:
  184. @map_only(Tensor)
  185. def go(t):
  186. return ...
  187. You can also directly use 'tree_map_only'
  188. """
  189. def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]:
  190. @functools.wraps(f)
  191. def inner(x: T) -> Any:
  192. if isinstance(x, ty):
  193. return f(x)
  194. else:
  195. return x
  196. return inner
  197. return deco
  198. @overload
  199. def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
  200. ...
  201. @overload
  202. def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
  203. ...
  204. @overload
  205. def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
  206. ...
  207. def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
  208. return tree_map(map_only(ty)(fn), pytree)
  209. def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
  210. flat_args, _ = tree_flatten(pytree)
  211. return all(map(pred, flat_args))
  212. def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
  213. flat_args, _ = tree_flatten(pytree)
  214. return any(map(pred, flat_args))
  215. @overload
  216. def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
  217. ...
  218. @overload
  219. def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
  220. ...
  221. @overload
  222. def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool:
  223. ...
  224. def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
  225. flat_args, _ = tree_flatten(pytree)
  226. return all(pred(x) for x in flat_args if isinstance(x, ty))
  227. @overload
  228. def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
  229. ...
  230. @overload
  231. def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
  232. ...
  233. def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
  234. flat_args, _ = tree_flatten(pytree)
  235. return any(pred(x) for x in flat_args if isinstance(x, ty))
  236. # Broadcasts a pytree to the provided TreeSpec and returns the flattened
  237. # values. If this is not possible, then this function returns None.
  238. #
  239. # For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
  240. # would return [0, 0]. This is useful for part of the vmap implementation:
  241. # a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
  242. # broadcastable to the tree structure of `inputs` and we use
  243. # _broadcast_to_and_flatten to check this.
  244. def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
  245. assert isinstance(spec, TreeSpec)
  246. if _is_leaf(pytree):
  247. return [pytree] * spec.num_leaves
  248. if isinstance(spec, LeafSpec):
  249. return None
  250. node_type = _get_node_type(pytree)
  251. if node_type != spec.type:
  252. return None
  253. flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
  254. child_pytrees, ctx = flatten_fn(pytree)
  255. # Check if the Node is different from the spec
  256. if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context:
  257. return None
  258. # Recursively flatten the children
  259. result : List[Any] = []
  260. for child, child_spec in zip(child_pytrees, spec.children_specs):
  261. flat = _broadcast_to_and_flatten(child, child_spec)
  262. if flat is not None:
  263. result += flat
  264. else:
  265. return None
  266. return result