dropout.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. from .module import Module
  2. from .. import functional as F
  3. from torch import Tensor
  4. __all__ = ['Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout']
  5. class _DropoutNd(Module):
  6. __constants__ = ['p', 'inplace']
  7. p: float
  8. inplace: bool
  9. def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
  10. super().__init__()
  11. if p < 0 or p > 1:
  12. raise ValueError("dropout probability has to be between 0 and 1, "
  13. "but got {}".format(p))
  14. self.p = p
  15. self.inplace = inplace
  16. def extra_repr(self) -> str:
  17. return 'p={}, inplace={}'.format(self.p, self.inplace)
  18. class Dropout(_DropoutNd):
  19. r"""During training, randomly zeroes some of the elements of the input
  20. tensor with probability :attr:`p` using samples from a Bernoulli
  21. distribution. Each channel will be zeroed out independently on every forward
  22. call.
  23. This has proven to be an effective technique for regularization and
  24. preventing the co-adaptation of neurons as described in the paper
  25. `Improving neural networks by preventing co-adaptation of feature
  26. detectors`_ .
  27. Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
  28. training. This means that during evaluation the module simply computes an
  29. identity function.
  30. Args:
  31. p: probability of an element to be zeroed. Default: 0.5
  32. inplace: If set to ``True``, will do this operation in-place. Default: ``False``
  33. Shape:
  34. - Input: :math:`(*)`. Input can be of any shape
  35. - Output: :math:`(*)`. Output is of the same shape as input
  36. Examples::
  37. >>> m = nn.Dropout(p=0.2)
  38. >>> input = torch.randn(20, 16)
  39. >>> output = m(input)
  40. .. _Improving neural networks by preventing co-adaptation of feature
  41. detectors: https://arxiv.org/abs/1207.0580
  42. """
  43. def forward(self, input: Tensor) -> Tensor:
  44. return F.dropout(input, self.p, self.training, self.inplace)
  45. class Dropout1d(_DropoutNd):
  46. r"""Randomly zero out entire channels (a channel is a 1D feature map,
  47. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  48. batched input is a 1D tensor :math:`\text{input}[i, j]`).
  49. Each channel will be zeroed out independently on every forward call with
  50. probability :attr:`p` using samples from a Bernoulli distribution.
  51. Usually the input comes from :class:`nn.Conv1d` modules.
  52. As described in the paper
  53. `Efficient Object Localization Using Convolutional Networks`_ ,
  54. if adjacent pixels within feature maps are strongly correlated
  55. (as is normally the case in early convolution layers) then i.i.d. dropout
  56. will not regularize the activations and will otherwise just result
  57. in an effective learning rate decrease.
  58. In this case, :func:`nn.Dropout1d` will help promote independence between
  59. feature maps and should be used instead.
  60. Args:
  61. p (float, optional): probability of an element to be zero-ed.
  62. inplace (bool, optional): If set to ``True``, will do this operation
  63. in-place
  64. Shape:
  65. - Input: :math:`(N, C, L)` or :math:`(C, L)`.
  66. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
  67. Examples::
  68. >>> m = nn.Dropout1d(p=0.2)
  69. >>> input = torch.randn(20, 16, 32)
  70. >>> output = m(input)
  71. .. _Efficient Object Localization Using Convolutional Networks:
  72. https://arxiv.org/abs/1411.4280
  73. """
  74. def forward(self, input: Tensor) -> Tensor:
  75. return F.dropout1d(input, self.p, self.training, self.inplace)
  76. class Dropout2d(_DropoutNd):
  77. r"""Randomly zero out entire channels (a channel is a 2D feature map,
  78. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  79. batched input is a 2D tensor :math:`\text{input}[i, j]`).
  80. Each channel will be zeroed out independently on every forward call with
  81. probability :attr:`p` using samples from a Bernoulli distribution.
  82. Usually the input comes from :class:`nn.Conv2d` modules.
  83. As described in the paper
  84. `Efficient Object Localization Using Convolutional Networks`_ ,
  85. if adjacent pixels within feature maps are strongly correlated
  86. (as is normally the case in early convolution layers) then i.i.d. dropout
  87. will not regularize the activations and will otherwise just result
  88. in an effective learning rate decrease.
  89. In this case, :func:`nn.Dropout2d` will help promote independence between
  90. feature maps and should be used instead.
  91. Args:
  92. p (float, optional): probability of an element to be zero-ed.
  93. inplace (bool, optional): If set to ``True``, will do this operation
  94. in-place
  95. .. warning ::
  96. Due to historical reasons, this class will perform 1D channel-wise dropout
  97. for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
  98. support inputs without a batch dimension of shape :math:`(C, H, W)`. This
  99. behavior will change in a future release to interpret 3D inputs as no-batch-dim
  100. inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
  101. Shape:
  102. - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
  103. - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
  104. Examples::
  105. >>> m = nn.Dropout2d(p=0.2)
  106. >>> input = torch.randn(20, 16, 32, 32)
  107. >>> output = m(input)
  108. .. _Efficient Object Localization Using Convolutional Networks:
  109. https://arxiv.org/abs/1411.4280
  110. """
  111. def forward(self, input: Tensor) -> Tensor:
  112. return F.dropout2d(input, self.p, self.training, self.inplace)
  113. class Dropout3d(_DropoutNd):
  114. r"""Randomly zero out entire channels (a channel is a 3D feature map,
  115. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  116. batched input is a 3D tensor :math:`\text{input}[i, j]`).
  117. Each channel will be zeroed out independently on every forward call with
  118. probability :attr:`p` using samples from a Bernoulli distribution.
  119. Usually the input comes from :class:`nn.Conv3d` modules.
  120. As described in the paper
  121. `Efficient Object Localization Using Convolutional Networks`_ ,
  122. if adjacent pixels within feature maps are strongly correlated
  123. (as is normally the case in early convolution layers) then i.i.d. dropout
  124. will not regularize the activations and will otherwise just result
  125. in an effective learning rate decrease.
  126. In this case, :func:`nn.Dropout3d` will help promote independence between
  127. feature maps and should be used instead.
  128. Args:
  129. p (float, optional): probability of an element to be zeroed.
  130. inplace (bool, optional): If set to ``True``, will do this operation
  131. in-place
  132. Shape:
  133. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  134. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  135. Examples::
  136. >>> m = nn.Dropout3d(p=0.2)
  137. >>> input = torch.randn(20, 16, 4, 32, 32)
  138. >>> output = m(input)
  139. .. _Efficient Object Localization Using Convolutional Networks:
  140. https://arxiv.org/abs/1411.4280
  141. """
  142. def forward(self, input: Tensor) -> Tensor:
  143. return F.dropout3d(input, self.p, self.training, self.inplace)
  144. class AlphaDropout(_DropoutNd):
  145. r"""Applies Alpha Dropout over the input.
  146. Alpha Dropout is a type of Dropout that maintains the self-normalizing
  147. property.
  148. For an input with zero mean and unit standard deviation, the output of
  149. Alpha Dropout maintains the original mean and standard deviation of the
  150. input.
  151. Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
  152. that the outputs have zero mean and unit standard deviation.
  153. During training, it randomly masks some of the elements of the input
  154. tensor with probability *p* using samples from a bernoulli distribution.
  155. The elements to masked are randomized on every forward call, and scaled
  156. and shifted to maintain zero mean and unit standard deviation.
  157. During evaluation the module simply computes an identity function.
  158. More details can be found in the paper `Self-Normalizing Neural Networks`_ .
  159. Args:
  160. p (float): probability of an element to be dropped. Default: 0.5
  161. inplace (bool, optional): If set to ``True``, will do this operation
  162. in-place
  163. Shape:
  164. - Input: :math:`(*)`. Input can be of any shape
  165. - Output: :math:`(*)`. Output is of the same shape as input
  166. Examples::
  167. >>> m = nn.AlphaDropout(p=0.2)
  168. >>> input = torch.randn(20, 16)
  169. >>> output = m(input)
  170. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  171. """
  172. def forward(self, input: Tensor) -> Tensor:
  173. return F.alpha_dropout(input, self.p, self.training)
  174. class FeatureAlphaDropout(_DropoutNd):
  175. r"""Randomly masks out entire channels (a channel is a feature map,
  176. e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
  177. is a tensor :math:`\text{input}[i, j]`) of the input tensor). Instead of
  178. setting activations to zero, as in regular Dropout, the activations are set
  179. to the negative saturation value of the SELU activation function. More details
  180. can be found in the paper `Self-Normalizing Neural Networks`_ .
  181. Each element will be masked independently for each sample on every forward
  182. call with probability :attr:`p` using samples from a Bernoulli distribution.
  183. The elements to be masked are randomized on every forward call, and scaled
  184. and shifted to maintain zero mean and unit variance.
  185. Usually the input comes from :class:`nn.AlphaDropout` modules.
  186. As described in the paper
  187. `Efficient Object Localization Using Convolutional Networks`_ ,
  188. if adjacent pixels within feature maps are strongly correlated
  189. (as is normally the case in early convolution layers) then i.i.d. dropout
  190. will not regularize the activations and will otherwise just result
  191. in an effective learning rate decrease.
  192. In this case, :func:`nn.AlphaDropout` will help promote independence between
  193. feature maps and should be used instead.
  194. Args:
  195. p (float, optional): probability of an element to be zeroed. Default: 0.5
  196. inplace (bool, optional): If set to ``True``, will do this operation
  197. in-place
  198. Shape:
  199. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  200. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  201. Examples::
  202. >>> m = nn.FeatureAlphaDropout(p=0.2)
  203. >>> input = torch.randn(20, 16, 4, 32, 32)
  204. >>> output = m(input)
  205. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  206. .. _Efficient Object Localization Using Convolutional Networks:
  207. https://arxiv.org/abs/1411.4280
  208. """
  209. def forward(self, input: Tensor) -> Tensor:
  210. return F.feature_alpha_dropout(input, self.p, self.training)