123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- from collections import OrderedDict
- """
- This file contains helper functions that implement experimental functionality
- for named tensors in python. All of these are experimental, unstable, and
- subject to change or deletion.
- """
def check_serializing_named_tensor(tensor):
    """Refuse to serialize a tensor that carries dimension names.

    Serialization of named tensors is not implemented yet (NYI), so the caller
    is told to strip the names before serializing.

    Raises:
        RuntimeError: if ``tensor.has_names()`` is truthy.
    """
    if not tensor.has_names():
        return
    raise RuntimeError(
        "NYI: Named tensors don't support serialization. Please drop "
        "names via `tensor = tensor.rename(None)` before serialization."
    )
def build_dim_map(tensor):
    """Map every dim of ``tensor`` to its name, in dim order.

    A named dim is keyed by its name. An unnamed dim (name ``None``) is keyed
    by its positional index. The value is always the dim's name, which is
    ``None`` for unnamed dims.
    """
    dim_map = OrderedDict()
    for position, dim_name in enumerate(tensor.names):
        key = position if dim_name is None else dim_name
        dim_map[key] = dim_name
    return dim_map
def unzip_namedshape(namedshape):
    """Split a named shape into its names and its sizes.

    Args:
        namedshape: an ``OrderedDict`` mapping name -> size, or an iterable
            of ``(name, size)`` pairs. One-shot iterables such as generators
            are accepted too.

    Returns:
        ``zip(*pairs)``. For ``(name, size)`` pairs this yields the tuple of
        names followed by the tuple of sizes, so callers can write
        ``names, sizes = unzip_namedshape(...)``.

    Raises:
        RuntimeError: if ``namedshape`` is not iterable or is empty.
    """
    if isinstance(namedshape, OrderedDict):
        namedshape = namedshape.items()
    # Every tuple defines __iter__, so this single check covers tuples as well.
    if not hasattr(namedshape, "__iter__"):
        raise RuntimeError(
            "Expected namedshape to be OrderedDict or iterable of tuples, got: {}".format(
                type(namedshape)
            )
        )
    # Generators and other iterators have no len(). Materialize them so the
    # emptiness check below works instead of failing with a TypeError.
    if not hasattr(namedshape, "__len__"):
        namedshape = list(namedshape)
    if len(namedshape) == 0:
        raise RuntimeError("Expected namedshape to be non-empty.")
    return zip(*namedshape)
def namer_api_name(inplace):
    """Return the public method name to report: ``"rename_"`` when
    ``inplace`` is truthy, otherwise ``"rename"``."""
    return "rename_" if inplace else "rename"
def is_ellipsis(item):
    """Return whether ``item`` is an ellipsis marker: the ``...`` object or
    the string ``"..."``."""
    matches_ellipsis_object = item == ...
    return matches_ellipsis_object or item == "..."
def single_ellipsis_index(names, fn_name):
    """Return the index of the single ellipsis in ``names``, or ``None`` if
    there is none.

    ``fn_name`` is only used as the prefix of the error message.

    Raises:
        RuntimeError: if ``names`` contains two or more ellipses.
    """
    positions = [pos for pos, candidate in enumerate(names) if is_ellipsis(candidate)]
    if len(positions) > 1:
        raise RuntimeError(
            "{}: More than one Ellipsis ('...') found in names ("
            "{}). This function supports up to one Ellipsis.".format(fn_name, names)
        )
    return positions[0] if positions else None
def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
    """Return the run of ``names`` that an ellipsis stands for.

    The first ``numel_pre_glob`` entries and the last ``numel_post_glob``
    entries are skipped. Everything between them is returned as a slice of
    ``names``, which keeps its sequence type.
    """
    end = len(names) - numel_post_glob
    return names[numel_pre_glob:end]
def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
    """Return ``names`` with the entry at ``ellipsis_idx`` replaced by the
    matching run of ``tensor_names``.

    The run is chosen so that the explicit names before the ellipsis line up
    with the start of ``tensor_names``, and the explicit names after it line
    up with the end. The pieces are joined with ``+``, so ``names`` and
    ``tensor_names`` must be concatenable (for example, both tuples).
    """
    num_trailing = len(names) - ellipsis_idx - 1
    globbed = expand_single_ellipsis(ellipsis_idx, num_trailing, tensor_names)
    leading = names[:ellipsis_idx]
    trailing = names[ellipsis_idx + 1 :]
    return leading + globbed + trailing
def resolve_ellipsis(names, tensor_names, fn_name):
    """Expand a single ``...`` in ``names`` using entries of ``tensor_names``.

    If ``names`` contains no ellipsis, the same object is returned unchanged.
    Otherwise the ellipsis is replaced by the entries of ``tensor_names`` at
    the positions it covers, meaning everything between the explicit names
    before it and after it.

    Raises:
        RuntimeError: raised by ``single_ellipsis_index`` when ``names``
            holds more than one ellipsis. ``fn_name`` prefixes that message.
    """
    position = single_ellipsis_index(names, fn_name)
    if position is not None:
        return replace_ellipsis_by_position(position, names, tensor_names)
    return names
def update_names_with_list(tensor, names, inplace):
    """Set the dim names of ``tensor`` from a positional sequence of names.

    A sequence holding exactly one ``None`` (``tensor.rename(None)``) is
    forwarded as a bare ``None``. In every other case, any single ellipsis in
    ``names`` is first expanded against ``tensor.names``. The result is passed
    to ``tensor._update_names`` together with ``inplace``.
    """
    if len(names) == 1 and names[0] is None:
        # tensor.rename(None): hand over None itself, not a one-element list.
        new_names = None
    else:
        current_names = tensor.names
        new_names = resolve_ellipsis(names, current_names, namer_api_name(inplace))
    return tensor._update_names(new_names, inplace)
def update_names_with_mapping(tensor, rename_map, inplace):
    """Apply an ``{old_name: new_name}`` mapping to the dim names of ``tensor``.

    Dims not mentioned in ``rename_map`` keep their current name. The
    resulting names, in dim order, are passed to ``tensor._update_names``.

    Args:
        tensor: the tensor whose dims are renamed.
        rename_map: mapping from existing dim names to replacement names.
            These are the ``**rename_map`` keyword arguments in
            ``tensor.rename(**rename_map)``.
        inplace: forwarded to ``tensor._update_names``. It also selects
            whether ``rename_`` or ``rename`` appears in the error message.

    Raises:
        RuntimeError: if a key of ``rename_map`` does not name a dim of
            ``tensor``.
    """
    dim_map = build_dim_map(tensor)
    for old_dim, new_dim in rename_map.items():
        # dim_map keys are the existing dim names (or the index of an
        # unnamed dim). Anything else is not a dim of this tensor.
        if old_dim not in dim_map:
            raise RuntimeError(
                (
                    "{api_name}: Tried to rename dim '{old_dim}' to dim "
                    "{new_dim} in Tensor[{dims}] but dim '{old_dim}' does not exist"
                ).format(
                    old_dim=old_dim,
                    new_dim=new_dim,
                    dims=tensor.names,
                    api_name=namer_api_name(inplace),
                )
            )
        dim_map[old_dim] = new_dim
    # dim_map was built in dim order, so its values line up with the dims.
    return tensor._update_names(tuple(dim_map.values()), inplace)
def update_names(tensor, names, rename_map, inplace):
    """Shared logic behind ``tensor.rename`` and ``tensor.rename_`` (``inplace``
    selects which one is reported in errors).

    There are two usages:

    ``tensor.rename(*names)`` returns a view on ``tensor`` with named dims
    ``names``. ``names`` must be of length ``tensor.dim()``; otherwise, if
    ``'...'`` is in ``names``, it is expanded greedily to be equal to the
    corresponding names from ``tensor.names``.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename('...', 'height', 'width').names
    ('N', 'C', 'height', 'width')
    >>> # xdoctest: +SKIP
    >>> x.rename('batch', '...', 'width').names
    ('batch', 'C', 'H', 'width')
    ```

    ``tensor.rename(**rename_map)`` returns a view on ``tensor`` that has
    renamed dims as specified in the mapping ``rename_map``.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename(W='width', H='height').names
    ('N', 'C', 'height', 'width')
    ```

    Finally, ``tensor.rename`` has an in-place version called
    ``tensor.rename_``.

    Raises:
        RuntimeError: if both positional ``names`` and keyword ``rename_map``
            entries are given.
    """
    positional = len(names) > 0
    keyword = bool(rename_map)
    if positional and keyword:
        raise RuntimeError(
            "{api_name}: This function takes either positional "
            "args or keyword args, but not both. Use tensor.{api_name}(*names) "
            "to name dims and tensor.{api_name}(**rename_map) to rename "
            "dims.".format(api_name=namer_api_name(inplace))
        )
    # Positional names go through the list path. So does a call with no
    # arguments at all (tensor.rename(*[])), which is valid for a 0-dim tensor.
    if positional or not keyword:
        return update_names_with_list(tensor, names, inplace)
    return update_names_with_mapping(tensor, rename_map, inplace)
|