123456789101112131415161718192021222324 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- from torch.utils._pytree import tree_flatten, tree_unflatten
def tree_map_(fn_, pytree):
    """Call ``fn_`` on every leaf of ``pytree`` and return ``pytree`` itself.

    Leaves are visited in the order produced by ``tree_flatten``. The return
    values of ``fn_`` are discarded, so ``fn_`` only has an effect through
    side effects (e.g. mutating a leaf in place). No new tree is built.

    Args:
        fn_: Callable invoked once per leaf with that leaf as its only argument.
        pytree: The structure to flatten with ``tree_flatten``.

    Returns:
        The input ``pytree`` object (the same object, not a copy).
    """
    flat_args, _ = tree_flatten(pytree)
    # Plain loop instead of a list comprehension: only fn_'s side effects
    # matter, so there is no reason to build a throwaway list of its results.
    for arg in flat_args:
        fn_(arg)
    return pytree
class PlaceHolder:
    """Stand-in for a pytree leaf whose repr is the single character ``*``.

    ``treespec_pprint`` fills every leaf slot with one of these so that the
    printed result shows the container structure rather than leaf values.
    """

    def __repr__(self):
        return '*'
def treespec_pprint(spec):
    """Render the structure described by a pytree spec as a string.

    One ``PlaceHolder`` (repr ``'*'``) is created per leaf slot of ``spec``.
    The tree is rebuilt from them with ``tree_unflatten``, and the repr of the
    rebuilt object is returned. For example, a spec for ``[x, (y, z)]``
    renders as ``[*, (*, *)]``.
    """
    placeholders = [PlaceHolder() for _ in range(spec.num_leaves)]
    return repr(tree_unflatten(placeholders, spec))
|