12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import PIL.Image
- import torch
- from torchvision import tv_tensors
- from torchvision.transforms.functional import pil_to_tensor, to_pil_image
- from torchvision.utils import _log_api_usage_once
- from ._utils import _get_kernel, _register_kernel_internal
def erase(
    inpt: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """[BETA] See :class:`~torchvision.transforms.v2.RandomErasing` for details.

    Overwrite an ``h`` x ``w`` window of ``inpt`` with ``v``. The window spans
    indices ``i:i + h`` of the second-to-last dimension and ``j:j + w`` of the
    last dimension. The work is done by the kernel registered for
    ``type(inpt)``; this module registers kernels for ``torch.Tensor``,
    ``tv_tensors.Image``, ``tv_tensors.Video`` and ``PIL.Image.Image``.

    Args:
        inpt: Input to erase.
        i: Start index of the window along the second-to-last dimension.
        j: Start index of the window along the last dimension.
        h: Size of the window along the second-to-last dimension.
        w: Size of the window along the last dimension.
        v: Values written into the window. They must be assignable to
            (broadcastable into) the selected slice.
        inplace: If ``True``, tensor inputs are modified in place. Otherwise
            a copy is erased and returned.

    Returns:
        The erased output, produced by the dispatched kernel.
    """
    if torch.jit.is_scripting():
        # Under TorchScript, skip the type-based kernel lookup and call the
        # tensor kernel directly (presumably because the registry is not
        # scriptable).
        return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

    _log_api_usage_once(erase)

    kernel = _get_kernel(erase, type(inpt))
    return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, tv_tensors.Image)
def erase_image(
    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Tensor / ``tv_tensors.Image`` kernel for :func:`erase`.

    Writes ``v`` into ``image[..., i:i + h, j:j + w]``. The leading ``...``
    means every slice over the preceding dimensions gets the same window.

    Args:
        image: Tensor to erase.
        i, j: Start indices along the last two dimensions.
        h, w: Window extent along the last two dimensions.
        v: Values assigned into the window.
        inplace: If ``True``, ``image`` itself is modified and returned.
            Otherwise a clone is modified and ``image`` is left untouched.

    Returns:
        The erased tensor: ``image`` itself if ``inplace`` is true, otherwise
        a new tensor.
    """
    target = image if inplace else image.clone()
    row_end = i + h
    col_end = j + w
    target[..., i:row_end, j:col_end] = v
    return target
@_register_kernel_internal(erase, PIL.Image.Image)
def _erase_image_pil(
    image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
    """PIL kernel for :func:`erase`.

    Converts ``image`` to a tensor, erases the window with :func:`erase_image`,
    and converts the result back to PIL using the original ``image.mode``.

    ``inplace`` is accepted so the signature matches the other kernels, but it
    has no observable effect here. The input PIL image is never mutated, and
    a new PIL image is always returned.
    """
    t_img = pil_to_tensor(image)
    # t_img is a tensor private to this call (pil_to_tensor builds it from the
    # PIL pixel data), so erasing it in place is safe. This avoids the extra
    # full-image clone that inplace=False would force in erase_image.
    output = erase_image(t_img, i=i, j=j, h=h, w=w, v=v, inplace=True)
    return to_pil_image(output, mode=image.mode)
@_register_kernel_internal(erase, tv_tensors.Video)
def erase_video(
    video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """``tv_tensors.Video`` kernel for :func:`erase`.

    Delegates to :func:`erase_image`. That kernel only slices the last two
    dimensions and uses ``...`` for the rest, so every leading index gets the
    same window. The leading dimensions are presumably frames/channels for
    videos.
    """
    return erase_image(video, i, j, h, w, v, inplace)
|