stochastic_depth.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import torch
  2. import torch.fx
  3. from torch import nn, Tensor
  4. from ..utils import _log_api_usage_once
  5. def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor:
  6. """
  7. Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
  8. <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
  9. branches of residual architectures.
  10. Args:
  11. input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
  12. being its batch i.e. a batch with ``N`` rows.
  13. p (float): probability of the input to be zeroed.
  14. mode (str): ``"batch"`` or ``"row"``.
  15. ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
  16. randomly selected rows from the batch.
  17. training: apply stochastic depth if is ``True``. Default: ``True``
  18. Returns:
  19. Tensor[N, ...]: The randomly zeroed tensor.
  20. """
  21. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  22. _log_api_usage_once(stochastic_depth)
  23. if p < 0.0 or p > 1.0:
  24. raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
  25. if mode not in ["batch", "row"]:
  26. raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
  27. if not training or p == 0.0:
  28. return input
  29. survival_rate = 1.0 - p
  30. if mode == "row":
  31. size = [input.shape[0]] + [1] * (input.ndim - 1)
  32. else:
  33. size = [1] * input.ndim
  34. noise = torch.empty(size, dtype=input.dtype, device=input.device)
  35. noise = noise.bernoulli_(survival_rate)
  36. if survival_rate > 0.0:
  37. noise.div_(survival_rate)
  38. return input * noise
  39. torch.fx.wrap("stochastic_depth")
  40. class StochasticDepth(nn.Module):
  41. """
  42. See :func:`stochastic_depth`.
  43. """
  44. def __init__(self, p: float, mode: str) -> None:
  45. super().__init__()
  46. _log_api_usage_once(self)
  47. self.p = p
  48. self.mode = mode
  49. def forward(self, input: Tensor) -> Tensor:
  50. return stochastic_depth(input, self.p, self.mode, self.training)
  51. def __repr__(self) -> str:
  52. s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
  53. return s