# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Tracks the running statistics per mini-batch instead of per micro-batch."""
from typing import TypeVar, cast

import torch
from torch import Tensor, nn
from torch.nn.functional import batch_norm
from torch.nn.modules.batchnorm import _BatchNorm

from .checkpoint import is_recomputing

__all__ = ["DeferredBatchNorm"]

TModule = TypeVar("TModule", bound=nn.Module)


class DeferredBatchNorm(_BatchNorm):
    """A BatchNorm layer that tracks multiple micro-batches so that the
    running statistics are updated once per mini-batch.
    """

    sum: Tensor
    sum_squares: Tensor
    running_mean: Tensor
    running_var: Tensor
    num_batches_tracked: Tensor

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        chunks: int = 1,
    ) -> None:
        super().__init__(num_features, eps, momentum, affine, track_running_stats=True)

        self.register_buffer("sum", torch.zeros_like(self.running_mean))
        self.register_buffer("sum_squares", torch.zeros_like(self.running_var))

        self.counter = 0  # number of elements accumulated per channel
        self.tracked = 0  # number of micro-batches tracked so far
        self.chunks = chunks  # micro-batches per mini-batch

    def _check_input_dim(self, input: Tensor) -> None:
        # It's the typical _check_input_dim() implementation in PyTorch.
        if input.dim() <= 2:
            raise ValueError("expected at least 3D input (got %dD input)" % input.dim())

    def _track(self, input: Tensor) -> bool:
        """Tracks statistics of a micro-batch."""
        # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
        dim = [0]
        dim.extend(range(2, input.dim()))

        with torch.no_grad():
            self.sum += input.sum(dim)
            self.sum_squares += (input ** 2).sum(dim)

        size = input.size().numel() // input.size(1)
        self.counter += size
        self.tracked += 1

        return self.tracked == self.chunks

    def _commit(self) -> None:
        """Updates the running statistics of a mini-batch."""
        exponential_average_factor = 0.0
        self.num_batches_tracked += 1
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / float(self.num_batches_tracked)
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

        # var(x) = E[x^2] - E[x]^2, computed over the whole mini-batch
        # (biased estimator).
        mean = self.sum / self.counter
        var = self.sum_squares / self.counter - mean ** 2

        # Calculate the exponential moving average here.
        m = exponential_average_factor

        self.running_mean *= 1 - m
        self.running_mean += mean * m

        self.running_var *= 1 - m
        self.running_var += var * m

        self.sum.zero_()
        self.sum_squares.zero_()
        self.counter = 0
        self.tracked = 0

    def forward(self, input: Tensor) -> Tensor:
        if not self.training:
            # Don't train parameters in evaluation mode.
            return batch_norm(
                input,
                running_mean=self.running_mean,
                running_var=self.running_var,
                weight=self.weight,
                bias=self.bias,
                training=False,
                momentum=0.0,
                eps=self.eps,
            )

        if not is_recomputing():
            # Track a micro-batch in training mode,
            # but not under a recomputation.
            tracked_enough = self._track(input)

            # Update the running statistics for a mini-batch
            # if it has tracked enough micro-batches.
            if tracked_enough:
                self._commit()

        # Normalize a micro-batch and train the parameters.
        return batch_norm(
            input,
            running_mean=None,
            running_var=None,
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=0.0,
            eps=self.eps,
        )

    @classmethod
    def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
        """Converts a :class:`nn.BatchNorm` module (or any :class:`nn.BatchNorm`
        submodules it contains) into :class:`DeferredBatchNorm`::

            from torchvision.models.resnet import resnet101
            from torchpipe.batchnorm import DeferredBatchNorm

            model = resnet101()
            model = DeferredBatchNorm.convert_deferred_batch_norm(model)

        """
        if isinstance(module, DeferredBatchNorm) and module.chunks == chunks:
            return cast(TModule, module)

        module_output: nn.Module = module

        if isinstance(module, _BatchNorm) and module.track_running_stats:
            module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
            if module.affine:
                module_output.register_parameter("weight", module.weight)
                module_output.register_parameter("bias", module.bias)
            module_output.register_buffer("running_mean", module.running_mean)
            module_output.register_buffer("running_var", module.running_var)
            module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)

        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))

        return cast(TModule, module_output)
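

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the original module. Because of the
# relative import of `is_recomputing` above, this file must live in a package,
# so run it as e.g. `python -m torchpipe.batchnorm`. A pipeline such as GPipe
# splits each mini-batch into `chunks` micro-batches; the running statistics
# are committed only after the last micro-batch of a mini-batch is tracked.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bn = DeferredBatchNorm(num_features=3, chunks=4)
    bn.train()

    # One mini-batch of 8 images, split into 4 micro-batches of 2.
    mini_batch = torch.randn(8, 3, 16, 16)
    for micro_batch in mini_batch.chunk(4):
        bn(micro_batch)  # stats are tracked here; _commit() fires on the 4th call

    # running_mean/running_var have now been updated exactly once, using the
    # statistics of the whole mini-batch rather than of each micro-batch.
    print(bn.running_mean)
    print(bn.running_var)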