_augment.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import PIL.Image
  2. import torch
  3. from torchvision import tv_tensors
  4. from torchvision.transforms.functional import pil_to_tensor, to_pil_image
  5. from torchvision.utils import _log_api_usage_once
  6. from ._utils import _get_kernel, _register_kernel_internal
  7. def erase(
  8. inpt: torch.Tensor,
  9. i: int,
  10. j: int,
  11. h: int,
  12. w: int,
  13. v: torch.Tensor,
  14. inplace: bool = False,
  15. ) -> torch.Tensor:
  16. """[BETA] See :class:`~torchvision.transforms.v2.RandomErase` for details."""
  17. if torch.jit.is_scripting():
  18. return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
  19. _log_api_usage_once(erase)
  20. kernel = _get_kernel(erase, type(inpt))
  21. return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
  22. @_register_kernel_internal(erase, torch.Tensor)
  23. @_register_kernel_internal(erase, tv_tensors.Image)
  24. def erase_image(
  25. image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
  26. ) -> torch.Tensor:
  27. if not inplace:
  28. image = image.clone()
  29. image[..., i : i + h, j : j + w] = v
  30. return image
  31. @_register_kernel_internal(erase, PIL.Image.Image)
  32. def _erase_image_pil(
  33. image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
  34. ) -> PIL.Image.Image:
  35. t_img = pil_to_tensor(image)
  36. output = erase_image(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
  37. return to_pil_image(output, mode=image.mode)
  38. @_register_kernel_internal(erase, tv_tensors.Video)
  39. def erase_video(
  40. video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
  41. ) -> torch.Tensor:
  42. return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)