123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional, TypeVar, overload, Union
- import functools
- from collections import namedtuple, OrderedDict
- from dataclasses import dataclass
# Type variables for the generic aliases and overloads in this file:
# T, S and U name the classes a caller filters on, R names the result type
# of a callback.
T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
R = TypeVar('R')
- """
- Contains utility functions for working with nested python data structures.
- A *pytree* is a nested Python data structure. It is a tree in the sense that
- nodes are Python collections (e.g., list, tuple, dict) and the leaves are
- Python values. Furthermore, a pytree should not contain reference cycles.
- pytrees are useful for working with nested collections of Tensors. For example,
- one can use `tree_map` to map a function over all Tensors inside some nested
- collection of Tensors and `tree_flatten` to get a flat list of all Tensors
- inside some nested collection. pytrees are helpful for implementing nested
- collection support for PyTorch APIs.
- This pytree implementation is not very performant due to Python overhead.
- To improve the performance we can move parts of the implementation to C++.
- """
# A NodeDef holds two callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.

# Auxiliary data returned by a flatten_fn and handed back to the matching
# unflatten_fn; its shape is up to each node type, hence `Any`.
Context = Any
# Any value handled by this module: either a registered container or a leaf.
PyTree = Any
# collection -> (list of children, context)
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
# (list of children, context) -> collection
UnflattenFunc = Callable[[List, Context], PyTree]


class NodeDef(NamedTuple):
    """The flatten/unflatten pair registered for one container type."""

    # Splits a collection into its children and a context.
    flatten_fn: FlattenFunc
    # Rebuilds the collection from children and that context.
    unflatten_fn: UnflattenFunc
# Registry of container types treated as pytree nodes, keyed by node type.
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}


def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
    """Register ``typ`` as a pytree node type.

    Args:
        typ: key under which the node is stored in ``SUPPORTED_NODES``.
        flatten_fn: callable splitting an instance into children and context.
        unflatten_fn: callable rebuilding an instance from children and context.

    An existing registration for ``typ`` is overwritten.
    """
    node_def = NodeDef(flatten_fn=flatten_fn, unflatten_fn=unflatten_fn)
    SUPPORTED_NODES[typ] = node_def
- def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
- return list(d.values()), list(d.keys())
- def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
- return {key: value for key, value in zip(context, values)}
- def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
- return d, None
- def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
- return list(values)
- def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
- return list(d), None
- def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
- return tuple(values)
- def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
- return list(d), type(d)
- def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
- return cast(NamedTuple, context(*values))
- def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]:
- return list(d.values()), list(d.keys())
- def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]':
- return OrderedDict((key, value) for key, value in zip(context, values))
- _register_pytree_node(dict, _dict_flatten, _dict_unflatten)
- _register_pytree_node(list, _list_flatten, _list_unflatten)
- _register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten)
- _register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten)
- _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
- # h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
- def _is_namedtuple_instance(pytree: Any) -> bool:
- typ = type(pytree)
- bases = typ.__bases__
- if len(bases) != 1 or bases[0] != tuple:
- return False
- fields = getattr(typ, '_fields', None)
- if not isinstance(fields, tuple):
- return False
- return all(type(entry) == str for entry in fields)
def _get_node_type(pytree: Any) -> Any:
    """Return the key used to look ``pytree`` up in ``SUPPORTED_NODES``.

    Named tuple instances all map to the ``namedtuple`` factory function;
    every other object maps to its own class.
    """
    return namedtuple if _is_namedtuple_instance(pytree) else type(pytree)
# A leaf is defined as anything that is not a Node.
def _is_leaf(pytree: PyTree) -> bool:
    """Return True when ``pytree``'s node type has no entry in ``SUPPORTED_NODES``."""
    node_type = _get_node_type(pytree)
    return node_type not in SUPPORTED_NODES
# A TreeSpec represents the structure of a pytree. It holds:
#   "type": the type of root Node of the pytree
#   context: some context that is useful in unflattening the pytree
#   children_specs: specs for each child of the root Node
#   num_leaves: the number of leaves
@dataclass
class TreeSpec:
    """Structural description of a pytree node and everything beneath it.

    ``num_leaves`` is not a dataclass field; it is derived from
    ``children_specs`` right after construction.
    """

    type: Any
    context: Context
    children_specs: List['TreeSpec']

    def __post_init__(self) -> None:
        # Total leaf count of the subtree = sum over the direct children.
        self.num_leaves: int = sum(child.num_leaves for child in self.children_specs)

    def __repr__(self, indent: int = 0) -> str:
        """Render the spec; children after the first go on their own line,
        aligned under the first child.

        ``indent`` is the column at which this spec starts, so that nested
        specs can line their own children up correctly.
        """
        head: str = f'TreeSpec({self.type.__name__}, {self.context}, ['
        if not self.children_specs:
            return head + '])'

        # Children are aligned with the column right after the opening '['.
        child_indent = indent + len(head)
        first, *rest = self.children_specs
        pieces = [first.__repr__(child_indent)]
        if rest:
            pieces.append(',')
        pieces.append(','.join(
            '\n' + ' ' * child_indent + child.__repr__(child_indent)
            for child in rest
        ))
        return head + ''.join(pieces) + '])'
class LeafSpec(TreeSpec):
    """Spec for a single leaf: no type, no context, no children, one leaf."""

    def __init__(self) -> None:
        super().__init__(type=None, context=None, children_specs=[])
        # The inherited __post_init__ derived 0 from the empty child list;
        # a leaf counts as exactly one value.
        self.num_leaves = 1

    def __repr__(self, indent: int = 0) -> str:
        # `indent` is accepted for signature compatibility with TreeSpec.
        return '*'
def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
    """Flatten a pytree into its leaves plus a spec describing its structure.

    Returns:
        A ``(leaves, spec)`` pair. ``leaves`` lists the leaf values in
        depth-first, left-to-right order; ``spec`` is the TreeSpec that
        ``tree_unflatten`` needs to rebuild the pytree.
    """
    if _is_leaf(pytree):
        return [pytree], LeafSpec()

    node_type = _get_node_type(pytree)
    children, context = SUPPORTED_NODES[node_type].flatten_fn(pytree)

    # Flatten each child in turn, concatenating leaves and collecting specs.
    leaves: List[Any] = []
    specs: List['TreeSpec'] = []
    for child in children:
        child_leaves, child_spec = tree_flatten(child)
        leaves.extend(child_leaves)
        specs.append(child_spec)

    return leaves, TreeSpec(node_type, context, specs)
def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
    """Rebuild a pytree from a flat list of leaves and a TreeSpec.

    This is the inverse operation of `tree_flatten`.

    Raises:
        ValueError: if ``spec`` is not a TreeSpec, or if ``len(values)`` does
            not equal ``spec.num_leaves``.
    """
    if not isinstance(spec, TreeSpec):
        raise ValueError(
            f'tree_unflatten(values, spec): Expected `spec` to be instance of '
            f'TreeSpec but got item of type {type(spec)}.')

    expected = spec.num_leaves
    if len(values) != expected:
        raise ValueError(
            f'tree_unflatten(values, spec): `values` has length {len(values)} '
            f'but the spec refers to a pytree that holds {spec.num_leaves} '
            f'items ({spec}).')

    if isinstance(spec, LeafSpec):
        return values[0]

    rebuild = SUPPORTED_NODES[spec.type].unflatten_fn

    # Hand each child the consecutive slice of leaves that belongs to it.
    offset = 0
    children = []
    for child_spec in spec.children_specs:
        stop = offset + child_spec.num_leaves
        children.append(tree_unflatten(values[offset:stop], child_spec))
        offset = stop

    return rebuild(children, spec.context)
def tree_map(fn: Any, pytree: PyTree) -> PyTree:
    """Apply ``fn`` to every leaf of ``pytree`` and return a pytree of the
    same structure holding the results.
    """
    leaves, spec = tree_flatten(pytree)
    mapped = [fn(leaf) for leaf in leaves]
    return tree_unflatten(mapped, spec)
# Aliases used by the overloads of `map_only` and the `tree_*_only` helpers.

# A tuple of exactly two / three classes.
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
# A single class, or a tuple of classes of any length.
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

# One-argument callables returning R whose argument is one of the three /
# two / one listed types.
Fn3 = Callable[[Union[T, S, U]], R]
Fn2 = Callable[[Union[T, S]], R]
Fn = Callable[[T], R]
# One-argument callable accepting anything and returning R.
FnAny = Callable[[Any], R]

# Callable that takes a T and returns a one-argument callable from Any to Any.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


# Three-class variant, matching the Type3 overloads of the `tree_*_only`
# helpers in this file.
@overload
def map_only(ty: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...


# This catch-all specialization is needed for the implementations below
# (e.g. tree_map_only) that call map_only with a plain TypeAny argument.
@overload
def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    """Build a decorator that applies a function only to values of type ``ty``.

    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged.  Ordinarily you would have to write::

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write::

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'

    Args:
        ty: a class or tuple of classes, used as the second argument of
            ``isinstance``.

    Returns:
        A decorator.  The decorated function calls the original for
        arguments that are instances of ``ty`` and returns every other
        argument unchanged.
    """
    def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(f)
        def inner(x: T) -> Any:
            if isinstance(x, ty):
                return f(x)
            # Anything that is not a `ty` instance passes through untouched.
            return x
        return inner
    return deco
@overload
def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
    ...


@overload
def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
    ...


@overload
def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree:
    ...


def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
    """Like ``tree_map``, but ``fn`` is applied only to leaves that are
    instances of ``ty``; all other leaves are kept as they are.
    """
    guarded_fn = map_only(ty)(fn)
    return tree_map(guarded_fn, pytree)
def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` holds for every leaf of ``pytree``.

    Evaluation stops at the first leaf for which ``pred`` is falsy. A pytree
    without leaves yields True.
    """
    leaves, _ = tree_flatten(pytree)
    return all(pred(leaf) for leaf in leaves)
def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` holds for at least one leaf of ``pytree``.

    Evaluation stops at the first leaf for which ``pred`` is truthy. A pytree
    without leaves yields False.
    """
    leaves, _ = tree_flatten(pytree)
    return any(pred(leaf) for leaf in leaves)
@overload
def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...


@overload
def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...


@overload
def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool:
    ...


def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` holds for every leaf that is an instance of ``ty``.

    Leaves of other types are ignored, so the result is True when no leaf
    matches ``ty``.
    """
    leaves, _ = tree_flatten(pytree)
    for leaf in leaves:
        # Only matching leaves are tested; the first failure decides.
        if isinstance(leaf, ty) and not pred(leaf):
            return False
    return True
@overload
def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...


@overload
def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...


# Three-class variant, for parity with tree_all_only and tree_map_only.
@overload
def tree_any_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool:
    ...


def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` holds for at least one leaf that is an
    instance of ``ty``.

    Leaves of other types are ignored, so the result is False when no leaf
    matches ``ty``.

    Args:
        ty: a class or tuple of classes, used as the second argument of
            ``isinstance``.
        pred: predicate called on each matching leaf.
        pytree: the pytree whose leaves are examined.
    """
    flat_args, _ = tree_flatten(pytree)
    return any(pred(x) for x in flat_args if isinstance(x, ty))
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
    """Flatten ``pytree`` against ``spec``, repeating leaves where needed.

    A leaf in ``pytree`` is repeated once per leaf of the spec subtree it is
    matched with. A container must match the spec node in type, number of
    children and context.

    Returns:
        The flat list of values, or None when ``pytree`` cannot be
        broadcast to ``spec``.
    """
    assert isinstance(spec, TreeSpec)

    if _is_leaf(pytree):
        return [pytree] * spec.num_leaves

    # A container cannot be matched against a single-leaf spec.
    if isinstance(spec, LeafSpec):
        return None

    node_type = _get_node_type(pytree)
    if node_type != spec.type:
        return None

    children, context = SUPPORTED_NODES[node_type].flatten_fn(pytree)

    # The node itself must agree with the spec before descending.
    if len(children) != len(spec.children_specs):
        return None
    if context != spec.context:
        return None

    # Broadcast each child against its spec; any failure fails the whole call.
    leaves: List[Any] = []
    for child, child_spec in zip(children, spec.children_specs):
        child_leaves = _broadcast_to_and_flatten(child, child_spec)
        if child_leaves is None:
            return None
        leaves.extend(child_leaves)
    return leaves
|