import dataclasses
import traceback
from collections import OrderedDict
from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union

import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
    _is_namedtuple,
)
from torch.nn.utils.rnn import PackedSequence
from torch.utils._mode_utils import no_dispatch


def _contains_batchnorm(module):
    return any(isinstance(mod, _BatchNorm) for mod in module.modules())


def _override_batchnorm_mixed_precision(module):
    for mod in module.modules():
        if isinstance(mod, _BatchNorm):
            mod._wrap_overrides = {"mixed_precision": None}  # type: ignore[assignment]


def _apply_to_tensors(
    fn: Callable,
    container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence],
) -> Any:
    """Recursively apply ``fn`` to all tensors in different kinds of container types."""

    def apply(
        x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]
    ) -> Any:
        if torch.is_tensor(x):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            dc = dataclasses.replace(x)
            for f in dataclasses.fields(dc):
                name = f.name
                setattr(dc, name, apply(getattr(dc, name)))
            return dc
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)


@torch.no_grad()
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
    """
    Allocate storage for ``tensor`` with the given size.

    Returns:
        bool: ``True`` if this method allocated storage and ``False`` if the
        storage was already allocated.
    """
    already_allocated = tensor._typed_storage()._size() == size.numel()
    if not already_allocated:
        tensor_storage_size = tensor._typed_storage()._size()
        p_assert(
            tensor_storage_size == 0,
            f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
        )
        tensor._typed_storage()._resize_(size.numel())
    return not already_allocated


@torch.no_grad()
def _free_storage(tensor: torch.Tensor) -> bool:
    """
    Frees the underlying storage of ``tensor``.

    Returns:
        bool: ``True`` if the method freed the storage and ``False`` if the
        storage was already freed.
    """
    already_freed = tensor._typed_storage()._size() == 0
    if not already_freed:
        p_assert(
            tensor.storage_offset() == 0,
            "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
            f"storage offset: {tensor.storage_offset()}\n"
            f"storage size: {tensor._typed_storage()._size()}\n"
            f"tensor shape: {tensor.shape}",
        )
        tensor._typed_storage()._resize_(0)
    return not already_freed


def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Returns whether ``x`` and ``y`` share the same storage."""
    # NOTE: CPU and GPU tensors are ensured to have different data pointers.
    return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()


def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """This is used as an alternative to ``assert`` when in the backward context
    to print the error message ``s`` since otherwise, it is swallowed."""
    if not cond:
        print(s)
        traceback.print_stack()
        if raise_assertion_error:
            raise AssertionError(s)


def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
    with no_dispatch():
        tensor.record_stream(cast(torch._C.Stream, stream))
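

# Illustrative usage sketch (not part of the original module; the demo function
# name ``_demo_apply_to_tensors`` is hypothetical). It shows how
# ``_apply_to_tensors`` walks an arbitrarily nested container, applies ``fn`` to
# every tensor leaf, and returns a container of the same shape with non-tensor
# leaves left unchanged.
def _demo_apply_to_tensors() -> None:
    batch = {
        "inputs": [torch.ones(2, requires_grad=True), torch.zeros(3)],
        "label": torch.tensor(1),
        "meta": "not a tensor",
    }
    # Detach every tensor in the nested structure.
    detached = _apply_to_tensors(lambda t: t.detach(), batch)
    assert isinstance(detached["inputs"], list)
    assert not detached["inputs"][0].requires_grad
    assert detached["meta"] == "not a tensor"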
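

# Illustrative sketch (not part of the original module; ``_demo_storage_lifecycle``
# is a hypothetical name). It shows the intended pairing of ``_free_storage`` and
# ``_alloc_storage``: the underlying storage is resized to 0 elements while the
# tensor's metadata (shape, dtype) is kept, and later restored to ``size.numel()``
# elements before the tensor is used again.
def _demo_storage_lifecycle() -> None:
    flat_param = torch.empty(8)
    size = flat_param.size()
    assert _free_storage(flat_param)  # returns True: storage was released
    assert flat_param._typed_storage()._size() == 0
    assert _alloc_storage(flat_param, size)  # returns True: storage re-allocated
    assert flat_param._typed_storage()._size() == size.numel()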