# -*- coding: utf-8 -*-
from .module import Module
from .. import functional as F
from torch import Tensor
from ..common_types import _size_any_t

__all__ = ['Fold', 'Unfold']


class Fold(Module):
    r"""Combines an array of sliding local blocks into a large containing
    tensor.

    Consider a batched :attr:`input` tensor containing sliding local blocks,
    e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
    where :math:`N` is the batch dimension, :math:`C \times \prod(\text{kernel\_size})`
    is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})`
    spatial locations each containing a :math:`C`-channeled vector), and
    :math:`L` is the total number of blocks. (This is exactly the
    same specification as the output shape of :class:`~torch.nn.Unfold`.) This
    operation combines these local blocks into the large :attr:`output` tensor
    of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
    by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the
    arguments must satisfy

    .. math::
        L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`d` is over all spatial dimensions.

    * :attr:`output_size` describes the spatial shape of the large containing
      tensor of the sliding local blocks. It is useful to resolve the ambiguity
      when multiple input shapes map to the same number of sliding blocks, e.g.,
      with ``stride > 1``.

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.

    * :attr:`padding` controls the amount of implicit zero padding added to
      both sides of each spatial dimension before reshaping.

    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this
      `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        output_size (int or tuple): the shape of the spatial dimensions of the
                                    output (i.e., ``output.sizes()[2:]``)
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in
                                         the input spatial dimensions.
                                         Default: 1

    * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`,
      :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then
      their values will be replicated across all spatial dimensions.

    * For the case of two output spatial dimensions this operation is sometimes
      called ``col2im``.

    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, then
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to constant divisor).

    .. warning::
        Currently, only unbatched (3D) or batched (4D) image-like output
        tensors are supported.

    Shape:
        - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)`
        - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
          or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above

    Examples::

        >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
        >>> input = torch.randn(1, 3 * 2 * 2, 12)
        >>> output = fold(input)
        >>> output.size()
        torch.Size([1, 3, 4, 5])

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """
    __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding',
                     'stride']
    output_size: _size_any_t
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        output_size: _size_any_t,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1
    ) -> None:
        super().__init__()
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.fold(input, self.output_size, self.kernel_size, self.dilation,
                      self.padding, self.stride)

    def extra_repr(self) -> str:
        return 'output_size={output_size}, kernel_size={kernel_size}, ' \
            'dilation={dilation}, padding={padding}, stride={stride}'.format(
                **self.__dict__
            )
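

# The note in the docstring above states that fold(unfold(input)) equals
# divisor * input, for a divisor that depends only on the input's shape and
# dtype. A minimal sketch verifying that identity (not part of this module's
# API; assumes only the public torch / torch.nn interface, and the concrete
# shapes are illustrative):
#
#     import torch
#     from torch import nn
#
#     fold_params = dict(kernel_size=(2, 2), dilation=1, padding=0, stride=1)
#     fold = nn.Fold(output_size=(4, 5), **fold_params)
#     unfold = nn.Unfold(**fold_params)
#
#     x = torch.randn(1, 3, 4, 5)
#     divisor = fold(unfold(torch.ones_like(x)))  # per-location block counts
#     assert torch.allclose(fold(unfold(x)), divisor * x)
#     assert (divisor > 0).all()  # so unfold is invertible here, up to divisor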


class Unfold(Module):
    r"""Extracts sliding local blocks from a batched input tensor.

    Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
    where :math:`N` is the batch dimension, :math:`C` is the channel dimension,
    and :math:`*` represents arbitrary spatial dimensions. This operation
    flattens each sliding :attr:`kernel_size`-sized block within the spatial
    dimensions of :attr:`input` into a column (i.e., last dimension) of a 3-D
    :attr:`output` tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
    where :math:`C \times \prod(\text{kernel\_size})` is the total number of
    values within each block (a block has :math:`\prod(\text{kernel\_size})`
    spatial locations each containing a :math:`C`-channeled vector), and
    :math:`L` is the total number of such blocks:

    .. math::
        L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`\text{spatial\_size}` is formed by the spatial dimensions
    of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial
    dimensions.

    Therefore, indexing :attr:`output` at the last dimension (column dimension)
    gives all values within a certain block.

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.

    * :attr:`padding` controls the amount of implicit zero padding added to
      both sides of each spatial dimension before reshaping.

    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this
      `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in
                                         the input spatial dimensions.
                                         Default: 1

    * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or
      :attr:`stride` is an int or a tuple of length 1, their values will be
      replicated across all spatial dimensions.

    * For the case of two input spatial dimensions this operation is sometimes
      called ``im2col``.

    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, then
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to constant divisor).

    .. warning::
        Currently, only 4-D input tensors (batched image-like tensors) are
        supported.

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above

    Examples::

        >>> unfold = nn.Unfold(kernel_size=(2, 3))
        >>> input = torch.randn(2, 5, 3, 4)
        >>> output = unfold(input)
        >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels)
        >>> # 4 blocks (2x3 kernels) in total in the 3x4 input
        >>> output.size()
        torch.Size([2, 30, 4])

        >>> # xdoctest: +IGNORE_WANT
        >>> # Convolution is equivalent to Unfold + Matrix Multiplication + Fold (or view to output shape)
        >>> inp = torch.randn(1, 3, 10, 12)
        >>> w = torch.randn(2, 3, 4, 5)
        >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
        >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
        >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
        >>> # or equivalently (and avoiding a copy),
        >>> # out = out_unf.view(1, 2, 7, 8)
        >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
        tensor(1.9073e-06)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """
    __constants__ = ['kernel_size', 'dilation', 'padding', 'stride']
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.unfold(input, self.kernel_size, self.dilation,
                        self.padding, self.stride)

    def extra_repr(self) -> str:
        return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \
            ' stride={stride}'.format(**self.__dict__)
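

if __name__ == "__main__":
    # Self-check sketch, not part of the module API; it assumes only public
    # torch behavior and illustrative shapes. Runnable in principle via
    # ``python -m torch.nn.modules.fold``. It verifies the L formula from the
    # Unfold docstring against the shape that Unfold actually produces.
    import torch

    kernel_size, dilation, padding, stride = (2, 3), (1, 1), (0, 0), (1, 1)
    unfold = Unfold(kernel_size, dilation=dilation, padding=padding, stride=stride)

    inp = torch.randn(2, 5, 3, 4)  # (N, C, H, W)
    out = unfold(inp)              # (N, C * prod(kernel_size), L)

    # L = prod_d floor((spatial[d] + 2 * padding[d]
    #                   - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1)
    L = 1
    for d, size in enumerate(inp.shape[2:]):
        L *= (size + 2 * padding[d]
              - dilation[d] * (kernel_size[d] - 1) - 1) // stride[d] + 1

    assert out.shape == (2, 5 * 2 * 3, L)  # torch.Size([2, 30, 4])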