pytree_hacks.py 612 B

123456789101112131415161718192021222324
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from torch.utils._pytree import tree_flatten, tree_unflatten
  7. def tree_map_(fn_, pytree):
  8. flat_args, _ = tree_flatten(pytree)
  9. [fn_(arg) for arg in flat_args]
  10. return pytree
  11. class PlaceHolder():
  12. def __repr__(self):
  13. return '*'
  14. def treespec_pprint(spec):
  15. leafs = [PlaceHolder() for _ in range(spec.num_leaves)]
  16. result = tree_unflatten(leafs, spec)
  17. return repr(result)