drop_block.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import torch
  2. import torch.fx
  3. import torch.nn.functional as F
  4. from torch import nn, Tensor
  5. from ..utils import _log_api_usage_once
  6. def drop_block2d(
  7. input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
  8. ) -> Tensor:
  9. """
  10. Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks"
  11. <https://arxiv.org/abs/1810.12890>`.
  12. Args:
  13. input (Tensor[N, C, H, W]): The input tensor or 4-dimensions with the first one
  14. being its batch i.e. a batch with ``N`` rows.
  15. p (float): Probability of an element to be dropped.
  16. block_size (int): Size of the block to drop.
  17. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
  18. eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
  19. training (bool): apply dropblock if is ``True``. Default: ``True``.
  20. Returns:
  21. Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock.
  22. """
  23. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  24. _log_api_usage_once(drop_block2d)
  25. if p < 0.0 or p > 1.0:
  26. raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
  27. if input.ndim != 4:
  28. raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.")
  29. if not training or p == 0.0:
  30. return input
  31. N, C, H, W = input.size()
  32. block_size = min(block_size, W, H)
  33. # compute the gamma of Bernoulli distribution
  34. gamma = (p * H * W) / ((block_size**2) * ((H - block_size + 1) * (W - block_size + 1)))
  35. noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device)
  36. noise.bernoulli_(gamma)
  37. noise = F.pad(noise, [block_size // 2] * 4, value=0)
  38. noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2)
  39. noise = 1 - noise
  40. normalize_scale = noise.numel() / (eps + noise.sum())
  41. if inplace:
  42. input.mul_(noise).mul_(normalize_scale)
  43. else:
  44. input = input * noise * normalize_scale
  45. return input
  46. def drop_block3d(
  47. input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
  48. ) -> Tensor:
  49. """
  50. Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks"
  51. <https://arxiv.org/abs/1810.12890>`.
  52. Args:
  53. input (Tensor[N, C, D, H, W]): The input tensor or 5-dimensions with the first one
  54. being its batch i.e. a batch with ``N`` rows.
  55. p (float): Probability of an element to be dropped.
  56. block_size (int): Size of the block to drop.
  57. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
  58. eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
  59. training (bool): apply dropblock if is ``True``. Default: ``True``.
  60. Returns:
  61. Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock.
  62. """
  63. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  64. _log_api_usage_once(drop_block3d)
  65. if p < 0.0 or p > 1.0:
  66. raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
  67. if input.ndim != 5:
  68. raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.")
  69. if not training or p == 0.0:
  70. return input
  71. N, C, D, H, W = input.size()
  72. block_size = min(block_size, D, H, W)
  73. # compute the gamma of Bernoulli distribution
  74. gamma = (p * D * H * W) / ((block_size**3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1)))
  75. noise = torch.empty(
  76. (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device
  77. )
  78. noise.bernoulli_(gamma)
  79. noise = F.pad(noise, [block_size // 2] * 6, value=0)
  80. noise = F.max_pool3d(
  81. noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size), padding=block_size // 2
  82. )
  83. noise = 1 - noise
  84. normalize_scale = noise.numel() / (eps + noise.sum())
  85. if inplace:
  86. input.mul_(noise).mul_(normalize_scale)
  87. else:
  88. input = input * noise * normalize_scale
  89. return input
  90. torch.fx.wrap("drop_block2d")
  91. class DropBlock2d(nn.Module):
  92. """
  93. See :func:`drop_block2d`.
  94. """
  95. def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
  96. super().__init__()
  97. self.p = p
  98. self.block_size = block_size
  99. self.inplace = inplace
  100. self.eps = eps
  101. def forward(self, input: Tensor) -> Tensor:
  102. """
  103. Args:
  104. input (Tensor): Input feature map on which some areas will be randomly
  105. dropped.
  106. Returns:
  107. Tensor: The tensor after DropBlock layer.
  108. """
  109. return drop_block2d(input, self.p, self.block_size, self.inplace, self.eps, self.training)
  110. def __repr__(self) -> str:
  111. s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})"
  112. return s
  113. torch.fx.wrap("drop_block3d")
  114. class DropBlock3d(DropBlock2d):
  115. """
  116. See :func:`drop_block3d`.
  117. """
  118. def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
  119. super().__init__(p, block_size, inplace, eps)
  120. def forward(self, input: Tensor) -> Tensor:
  121. """
  122. Args:
  123. input (Tensor): Input feature map on which some areas will be randomly
  124. dropped.
  125. Returns:
  126. Tensor: The tensor after DropBlock layer.
  127. """
  128. return drop_block3d(input, self.p, self.block_size, self.inplace, self.eps, self.training)