tree_map.py 372 B

123456789101112
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch._C import dim

# Module-level alias for the flattening helper exposed by functorch's C
# extension (``functorch._C.dim.tree_flatten``). Judging from its use in
# ``tree_map`` it returns a pair ``(leaves, unflatten)``; the exact set of
# container types it understands lives in the C implementation, which is
# not visible here.
tree_flatten = dim.tree_flatten
  8. def tree_map(fn, tree):
  9. vs, unflatten = tree_flatten(tree)
  10. return unflatten(fn(v) for v in vs)