channelshuffle.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from .module import Module
  2. from .. import functional as F
  3. from torch import Tensor
  4. __all__ = ['ChannelShuffle']
  class ChannelShuffle(Module):
      r"""Divide the channels of a tensor into groups and interleave them.

      The :math:`C` channels of a tensor of shape :math:`(*, C, H, W)` are
      divided into :math:`g` groups (``groups``) and rearranged as
      :math:`(*, \frac{C}{g}, g, H, W)`, while keeping the original tensor
      shape. :math:`C` is expected to be divisible by ``groups``.

      Args:
          groups (int): number of groups to divide channels in.

      Shape:
          - Input: :math:`(*, C, H, W)`
          - Output: :math:`(*, C, H, W)` (same shape as the input)

      Examples::

          >>> channel_shuffle = nn.ChannelShuffle(2)
          >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
          >>> input
          tensor([[[[ 1.,  2.],
                    [ 3.,  4.]],
                   [[ 5.,  6.],
                    [ 7.,  8.]],
                   [[ 9., 10.],
                    [11., 12.]],
                   [[13., 14.],
                    [15., 16.]]]])
          >>> output = channel_shuffle(input)
          >>> output
          tensor([[[[ 1.,  2.],
                    [ 3.,  4.]],
                   [[ 9., 10.],
                    [11., 12.]],
                   [[ 5.,  6.],
                    [ 7.,  8.]],
                   [[13., 14.],
                    [15., 16.]]]])
      """

      # Attributes listed here are treated as constants of the module.
      __constants__ = ['groups']
      groups: int

      def __init__(self, groups: int) -> None:
          """Store the number of channel groups used by :meth:`forward`.

          Args:
              groups: number of groups to divide channels in. It is not
                  validated here; presumably ``F.channel_shuffle`` rejects
                  invalid values at call time (NOTE(review): confirm).
          """
          super().__init__()
          self.groups = groups

      def forward(self, input: Tensor) -> Tensor:
          """Return ``input`` with its channels shuffled across ``self.groups`` groups.

          Delegates entirely to ``F.channel_shuffle``; the output has the same
          shape as ``input``.
          """
          return F.channel_shuffle(input, self.groups)

      def extra_repr(self) -> str:
          """Return the extra text shown in the module's repr, e.g. ``groups=2``."""
          return f'groups={self.groups}'