_utils.py

import dataclasses
import traceback
from collections import OrderedDict
from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union

import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
    _is_namedtuple,
)
from torch.nn.utils.rnn import PackedSequence
from torch.utils._mode_utils import no_dispatch


def _contains_batchnorm(module):
    return any(isinstance(mod, _BatchNorm) for mod in module.modules())


def _override_batchnorm_mixed_precision(module):
    for mod in module.modules():
        if isinstance(mod, _BatchNorm):
            mod._wrap_overrides = {"mixed_precision": None}  # type: ignore[assignment]
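
# Illustrative sketch (not part of the original file): these two helpers are
# typically used together when wrapping a model, so that any BatchNorm
# submodules are flagged to opt out of mixed precision and keep running in
# full precision.
#
#   model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
#   if _contains_batchnorm(model):
#       _override_batchnorm_mixed_precision(model)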


def _apply_to_tensors(
    fn: Callable,
    container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
) -> Any:
    """Recursively applies ``fn`` to all tensors in different kinds of container types."""

    def apply(
        x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
    ) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            dc = dataclasses.replace(x)
            for f in dataclasses.fields(dc):
                name = f.name
                setattr(dc, name, apply(getattr(dc, name)))
            return dc
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            # ``fn`` is applied to the underlying data only for its side
            # effects; the PackedSequence itself is returned unchanged.
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)
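
# Illustrative usage sketch (not part of the original file): apply a
# transformation to every tensor nested inside a mixed container while
# preserving the container structure.
#
#   nested = {"a": torch.ones(2), "b": [torch.zeros(3), "non-tensor"]}
#   detached = _apply_to_tensors(lambda t: t.detach(), nested)
#   # ``detached`` has the same dict/list layout; only the tensors were mapped.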


@torch.no_grad()
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
    """
    Allocate storage for ``tensor`` with the given size.

    Returns:
        bool: ``True`` if this method allocated storage and ``False`` if the
        storage was already allocated.
    """
    already_allocated = tensor._typed_storage()._size() == size.numel()
    if not already_allocated:
        tensor_storage_size = tensor._typed_storage()._size()
        p_assert(
            tensor_storage_size == 0,
            f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
        )
        tensor._typed_storage()._resize_(size.numel())
    return not already_allocated
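
# Illustrative sketch (not part of the original file): ``_alloc_storage`` is
# meant to be called on a tensor whose storage was previously freed (size 0),
# e.g. a flat parameter that is materialized right before use.
#
#   flat = torch.empty(0)                      # tensor with empty storage
#   _alloc_storage(flat, torch.Size([16]))     # True: storage resized to 16 elements
#   _alloc_storage(flat, torch.Size([16]))     # False: storage already allocated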


@torch.no_grad()
def _free_storage(tensor: torch.Tensor) -> bool:
    """
    Frees the underlying storage of ``tensor``.

    Returns:
        bool: ``True`` if the method freed the storage and ``False`` if the
        storage was already freed.
    """
    already_freed = tensor._typed_storage()._size() == 0
    if not already_freed:
        p_assert(
            tensor.storage_offset() == 0,
            "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
            f"storage offset: {tensor.storage_offset()}\n"
            f"storage size: {tensor._typed_storage()._size()}\n"
            f"tensor shape: {tensor.shape}",
        )
        tensor._typed_storage()._resize_(0)
    return not already_freed
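
# Illustrative sketch (not part of the original file): ``_free_storage`` is the
# counterpart of ``_alloc_storage``; together they release and re-materialize a
# tensor's memory without replacing the tensor object itself.
#
#   flat = torch.zeros(16)
#   _free_storage(flat)    # True: the storage is resized to 0 elements
#   _free_storage(flat)    # False: the storage was already freed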


def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Returns whether ``x`` and ``y`` share the same storage."""
    # NOTE: CPU and GPU tensors are ensured to have different data pointers.
    return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
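
# Illustrative usage sketch (not part of the original file): a view shares
# storage with its base tensor, while a clone does not.
#
#   base = torch.arange(8)
#   assert _same_storage(base, base[2:])            # view of the same storage
#   assert not _same_storage(base, base.clone())    # clone owns new storage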


def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """This is used as an alternative to ``assert`` when in the backward context
    to print the error message ``s`` since otherwise, it is swallowed."""
    if not cond:
        print(s)
        traceback.print_stack()
        if raise_assertion_error:
            raise AssertionError(s)
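
# Illustrative usage sketch (not part of the original file): unlike a bare
# ``assert``, the failure message and a stack trace are printed even in
# contexts where the exception itself would be swallowed, e.g. inside an
# autograd backward hook.
#
#   p_assert(x.dim() == 2, f"expected a 2D tensor, got {x.dim()}D")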


def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
    with no_dispatch():
        tensor.record_stream(cast(torch._C.Stream, stream))
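
# Illustrative usage sketch (not part of the original file): record that
# ``tensor`` is in use on a side CUDA stream so the caching allocator does not
# reuse its memory too early, while ``no_dispatch()`` bypasses any active
# tensor subclasses or dispatch modes.
#
#   if torch.cuda.is_available():
#       side_stream = torch.cuda.Stream()
#       with torch.cuda.stream(side_stream):
#           out = tensor * 2
#       _no_dispatch_record_stream(tensor, side_stream)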