_misc.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. import math
  2. from typing import List, Optional
  3. import PIL.Image
  4. import torch
  5. from torch.nn.functional import conv2d, pad as torch_pad
  6. from torchvision import tv_tensors
  7. from torchvision.transforms._functional_tensor import _max_value
  8. from torchvision.transforms.functional import pil_to_tensor, to_pil_image
  9. from torchvision.utils import _log_api_usage_once
  10. from ._utils import _get_kernel, _register_kernel_internal
  11. def normalize(
  12. inpt: torch.Tensor,
  13. mean: List[float],
  14. std: List[float],
  15. inplace: bool = False,
  16. ) -> torch.Tensor:
  17. """[BETA] See :class:`~torchvision.transforms.v2.Normalize` for details."""
  18. if torch.jit.is_scripting():
  19. return normalize_image(inpt, mean=mean, std=std, inplace=inplace)
  20. _log_api_usage_once(normalize)
  21. kernel = _get_kernel(normalize, type(inpt))
  22. return kernel(inpt, mean=mean, std=std, inplace=inplace)
  23. @_register_kernel_internal(normalize, torch.Tensor)
  24. @_register_kernel_internal(normalize, tv_tensors.Image)
  25. def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
  26. if not image.is_floating_point():
  27. raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
  28. if image.ndim < 3:
  29. raise ValueError(f"Expected tensor to be a tensor image of size (..., C, H, W). Got {image.shape}.")
  30. if isinstance(std, (tuple, list)):
  31. divzero = not all(std)
  32. elif isinstance(std, (int, float)):
  33. divzero = std == 0
  34. else:
  35. divzero = False
  36. if divzero:
  37. raise ValueError("std evaluated to zero, leading to division by zero.")
  38. dtype = image.dtype
  39. device = image.device
  40. mean = torch.as_tensor(mean, dtype=dtype, device=device)
  41. std = torch.as_tensor(std, dtype=dtype, device=device)
  42. if mean.ndim == 1:
  43. mean = mean.view(-1, 1, 1)
  44. if std.ndim == 1:
  45. std = std.view(-1, 1, 1)
  46. if inplace:
  47. image = image.sub_(mean)
  48. else:
  49. image = image.sub(mean)
  50. return image.div_(std)
  51. @_register_kernel_internal(normalize, tv_tensors.Video)
  52. def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
  53. return normalize_image(video, mean, std, inplace=inplace)
  54. def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
  55. """[BETA] See :class:`~torchvision.transforms.v2.GaussianBlur` for details."""
  56. if torch.jit.is_scripting():
  57. return gaussian_blur_image(inpt, kernel_size=kernel_size, sigma=sigma)
  58. _log_api_usage_once(gaussian_blur)
  59. kernel = _get_kernel(gaussian_blur, type(inpt))
  60. return kernel(inpt, kernel_size=kernel_size, sigma=sigma)
  61. def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
  62. lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
  63. x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
  64. kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
  65. return kernel1d
  66. def _get_gaussian_kernel2d(
  67. kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
  68. ) -> torch.Tensor:
  69. kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
  70. kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
  71. kernel2d = kernel1d_y.unsqueeze(-1) * kernel1d_x
  72. return kernel2d
  73. @_register_kernel_internal(gaussian_blur, torch.Tensor)
  74. @_register_kernel_internal(gaussian_blur, tv_tensors.Image)
  75. def gaussian_blur_image(
  76. image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
  77. ) -> torch.Tensor:
  78. # TODO: consider deprecating integers from sigma on the future
  79. if isinstance(kernel_size, int):
  80. kernel_size = [kernel_size, kernel_size]
  81. elif len(kernel_size) != 2:
  82. raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
  83. for ksize in kernel_size:
  84. if ksize % 2 == 0 or ksize < 0:
  85. raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
  86. if sigma is None:
  87. sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
  88. else:
  89. if isinstance(sigma, (list, tuple)):
  90. length = len(sigma)
  91. if length == 1:
  92. s = float(sigma[0])
  93. sigma = [s, s]
  94. elif length != 2:
  95. raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
  96. elif isinstance(sigma, (int, float)):
  97. s = float(sigma)
  98. sigma = [s, s]
  99. else:
  100. raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
  101. for s in sigma:
  102. if s <= 0.0:
  103. raise ValueError(f"sigma should have positive values. Got {sigma}")
  104. if image.numel() == 0:
  105. return image
  106. dtype = image.dtype
  107. shape = image.shape
  108. ndim = image.ndim
  109. if ndim == 3:
  110. image = image.unsqueeze(dim=0)
  111. elif ndim > 4:
  112. image = image.reshape((-1,) + shape[-3:])
  113. fp = torch.is_floating_point(image)
  114. kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
  115. kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])
  116. output = image if fp else image.to(dtype=torch.float32)
  117. # padding = (left, right, top, bottom)
  118. padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
  119. output = torch_pad(output, padding, mode="reflect")
  120. output = conv2d(output, kernel, groups=shape[-3])
  121. if ndim == 3:
  122. output = output.squeeze(dim=0)
  123. elif ndim > 4:
  124. output = output.reshape(shape)
  125. if not fp:
  126. output = output.round_().to(dtype=dtype)
  127. return output
  128. @_register_kernel_internal(gaussian_blur, PIL.Image.Image)
  129. def _gaussian_blur_image_pil(
  130. image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
  131. ) -> PIL.Image.Image:
  132. t_img = pil_to_tensor(image)
  133. output = gaussian_blur_image(t_img, kernel_size=kernel_size, sigma=sigma)
  134. return to_pil_image(output, mode=image.mode)
  135. @_register_kernel_internal(gaussian_blur, tv_tensors.Video)
  136. def gaussian_blur_video(
  137. video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
  138. ) -> torch.Tensor:
  139. return gaussian_blur_image(video, kernel_size, sigma)
  140. def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
  141. """[BETA] See :func:`~torchvision.transforms.v2.ToDtype` for details."""
  142. if torch.jit.is_scripting():
  143. return to_dtype_image(inpt, dtype=dtype, scale=scale)
  144. _log_api_usage_once(to_dtype)
  145. kernel = _get_kernel(to_dtype, type(inpt))
  146. return kernel(inpt, dtype=dtype, scale=scale)
  147. def _num_value_bits(dtype: torch.dtype) -> int:
  148. if dtype == torch.uint8:
  149. return 8
  150. elif dtype == torch.int8:
  151. return 7
  152. elif dtype == torch.int16:
  153. return 15
  154. elif dtype == torch.int32:
  155. return 31
  156. elif dtype == torch.int64:
  157. return 63
  158. else:
  159. raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
  160. @_register_kernel_internal(to_dtype, torch.Tensor)
  161. @_register_kernel_internal(to_dtype, tv_tensors.Image)
  162. def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
  163. if image.dtype == dtype:
  164. return image
  165. elif not scale:
  166. return image.to(dtype)
  167. float_input = image.is_floating_point()
  168. if torch.jit.is_scripting():
  169. # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
  170. float_output = torch.tensor(0, dtype=dtype).is_floating_point()
  171. else:
  172. float_output = dtype.is_floating_point
  173. if float_input:
  174. # float to float
  175. if float_output:
  176. return image.to(dtype)
  177. # float to int
  178. if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
  179. image.dtype == torch.float64 and dtype == torch.int64
  180. ):
  181. raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
  182. # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
  183. # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
  184. # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
  185. # for a detailed analysis.
  186. # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
  187. # Instead, we can also multiply by the maximum value plus something close to `1`. See
  188. # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
  189. eps = 1e-3
  190. max_value = float(_max_value(dtype))
  191. # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
  192. # discrete set `{0, 1}`.
  193. return image.mul(max_value + 1.0 - eps).to(dtype)
  194. else:
  195. # int to float
  196. if float_output:
  197. return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
  198. # int to int
  199. num_value_bits_input = _num_value_bits(image.dtype)
  200. num_value_bits_output = _num_value_bits(dtype)
  201. if num_value_bits_input > num_value_bits_output:
  202. return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
  203. else:
  204. return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
  205. # We encourage users to use to_dtype() instead but we keep this for BC
  206. def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
  207. """[BETA] [DEPRECATED] Use to_dtype() instead."""
  208. return to_dtype_image(image, dtype=dtype, scale=True)
  209. @_register_kernel_internal(to_dtype, tv_tensors.Video)
  210. def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
  211. return to_dtype_image(video, dtype, scale=scale)
  212. @_register_kernel_internal(to_dtype, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
  213. @_register_kernel_internal(to_dtype, tv_tensors.Mask, tv_tensor_wrapper=False)
  214. def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
  215. # We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
  216. return inpt.to(dtype)