_namedtensor_internals.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. from collections import OrderedDict
  2. """
  3. This file contains helper functions that implement experimental functionality
  4. for named tensors in python. All of these are experimental, unstable, and
  5. subject to change or deletion.
  6. """
  7. def check_serializing_named_tensor(tensor):
  8. if tensor.has_names():
  9. raise RuntimeError(
  10. "NYI: Named tensors don't support serialization. Please drop "
  11. "names via `tensor = tensor.rename(None)` before serialization."
  12. )
  13. def build_dim_map(tensor):
  14. """Returns a map of { dim: dim_name } where dim is a name if the dim is named
  15. and the dim index otherwise."""
  16. return OrderedDict(
  17. [(idx if name is None else name, name) for idx, name in enumerate(tensor.names)]
  18. )
  19. def unzip_namedshape(namedshape):
  20. if isinstance(namedshape, OrderedDict):
  21. namedshape = namedshape.items()
  22. if not hasattr(namedshape, "__iter__") and not isinstance(namedshape, tuple):
  23. raise RuntimeError(
  24. "Expected namedshape to be OrderedDict or iterable of tuples, got: {}".format(
  25. type(namedshape)
  26. )
  27. )
  28. if len(namedshape) == 0:
  29. raise RuntimeError("Expected namedshape to non-empty.")
  30. return zip(*namedshape)
  31. def namer_api_name(inplace):
  32. if inplace:
  33. return "rename_"
  34. else:
  35. return "rename"
  36. def is_ellipsis(item):
  37. return item == Ellipsis or item == "..."
  38. def single_ellipsis_index(names, fn_name):
  39. ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
  40. if len(ellipsis_indices) >= 2:
  41. raise RuntimeError(
  42. "{}: More than one Ellipsis ('...') found in names ("
  43. "{}). This function supports up to one Ellipsis.".format(fn_name, names)
  44. )
  45. if len(ellipsis_indices) == 1:
  46. return ellipsis_indices[0]
  47. return None
  48. def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
  49. return names[numel_pre_glob : len(names) - numel_post_glob]
  50. def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
  51. globbed_names = expand_single_ellipsis(
  52. ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names
  53. )
  54. return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :]
  55. def resolve_ellipsis(names, tensor_names, fn_name):
  56. """
  57. Expands ... inside `names` to be equal to a list of names from `tensor_names`.
  58. """
  59. ellipsis_idx = single_ellipsis_index(names, fn_name)
  60. if ellipsis_idx is None:
  61. return names
  62. return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)
  63. def update_names_with_list(tensor, names, inplace):
  64. # Special case for tensor.rename(None)
  65. if len(names) == 1 and names[0] is None:
  66. return tensor._update_names(None, inplace)
  67. return tensor._update_names(
  68. resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace
  69. )
  70. def update_names_with_mapping(tensor, rename_map, inplace):
  71. dim_map = build_dim_map(tensor)
  72. for old_dim in rename_map.keys():
  73. new_dim = rename_map[old_dim]
  74. if old_dim in dim_map.keys():
  75. dim_map[old_dim] = new_dim
  76. else:
  77. raise RuntimeError(
  78. (
  79. "{api_name}: Tried to rename dim '{old_dim}' to dim "
  80. "{new_dim} in Tensor[{dims}] but dim '{old_dim}' does not exist"
  81. ).format(
  82. old_dim=old_dim,
  83. new_dim=new_dim,
  84. dims=tensor.names,
  85. api_name=namer_api_name(inplace),
  86. )
  87. )
  88. return tensor._update_names(tuple(dim_map.values()), inplace)
  89. def update_names(tensor, names, rename_map, inplace):
  90. """There are two usages:
  91. tensor.rename(*names) returns a view on tensor with named dims `names`.
  92. `names` must be of length `tensor.dim()`; otherwise, if '...' is in `names`,
  93. then it is expanded greedily to be equal to the corresponding names from
  94. `tensor.names`.
  95. For example,
  96. ```
  97. >>> # xdoctest: +SKIP
  98. >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
  99. >>> x.rename('...', 'height', 'width').names
  100. ('N', 'C', 'height', 'width')
  101. >>> # xdoctest: +SKIP
  102. >>> x.rename('batch', '...', 'width').names
  103. ('batch', 'C', 'H', 'width')
  104. ```
  105. tensor.rename(**rename_map) returns a view on tensor that has rename dims
  106. as specified in the mapping `rename_map`.
  107. For example,
  108. ```
  109. >>> # xdoctest: +SKIP
  110. >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
  111. >>> x.rename(W='width', H='height').names
  112. ('N', 'C', 'height', 'width')
  113. ```
  114. Finally, tensor.rename has an in-place version called tensor.rename_.
  115. """
  116. has_names = len(names) > 0
  117. has_rename_pairs = bool(rename_map)
  118. if has_names and has_rename_pairs:
  119. raise RuntimeError(
  120. "{api_name}: This function takes either positional "
  121. "args or keyword args, but not both. Use tensor.{api_name}(*names) "
  122. "to name dims and tensor.{api_name}(**rename_map) to rename "
  123. "dims.".format(api_name=namer_api_name(inplace))
  124. )
  125. # Special case for tensor.rename(*[]), which is valid for a 0 dim tensor.
  126. if not has_names and not has_rename_pairs:
  127. return update_names_with_list(tensor, names, inplace)
  128. if has_names:
  129. return update_names_with_list(tensor, names, inplace)
  130. return update_names_with_mapping(tensor, rename_map, inplace)