routeddecoder.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from io import BufferedIOBase
  2. from typing import Any, Callable, Iterable, Iterator, Sized, Tuple
  3. from torch.utils.data.datapipes._decorator import functional_datapipe
  4. from torch.utils.data.datapipes.datapipe import IterDataPipe
  5. from torch.utils.data.datapipes.utils.common import _deprecation_warning
  6. from torch.utils.data.datapipes.utils.decoder import (
  7. Decoder,
  8. basichandlers as decoder_basichandlers,
  9. imagehandler as decoder_imagehandler,
  10. extension_extract_fn
  11. )
  12. __all__ = ["RoutedDecoderIterDataPipe", ]
  13. @functional_datapipe('routed_decode')
  14. class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]):
  15. r"""
  16. Decodes binary streams from input DataPipe, yields pathname and decoded data
  17. in a tuple (functional name: ``routed_decode``).
  18. Args:
  19. datapipe: Iterable datapipe that provides pathname and binary stream in tuples
  20. handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder
  21. handlers will be set as default. If multiple handles are provided, the priority
  22. order follows the order of handlers (the first handler has the top priority)
  23. key_fn: Function for decoder to extract key from pathname to dispatch handlers.
  24. Default is set to extract file extension from pathname
  25. Note:
  26. When ``key_fn`` is specified returning anything other than extension, the default
  27. handler will not work and users need to specify custom handler. Custom handler
  28. could use regex to determine the eligibility to handle data.
  29. """
  30. def __init__(self,
  31. datapipe: Iterable[Tuple[str, BufferedIOBase]],
  32. *handlers: Callable,
  33. key_fn: Callable = extension_extract_fn) -> None:
  34. super().__init__()
  35. self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
  36. if not handlers:
  37. handlers = (decoder_basichandlers, decoder_imagehandler('torch'))
  38. self.decoder = Decoder(*handlers, key_fn=key_fn)
  39. _deprecation_warning(
  40. type(self).__name__,
  41. deprecation_version="1.12",
  42. removal_version="1.13",
  43. old_functional_name="routed_decode",
  44. )
  45. def add_handler(self, *handler: Callable) -> None:
  46. self.decoder.add_handler(*handler)
  47. def __iter__(self) -> Iterator[Tuple[str, Any]]:
  48. for data in self.datapipe:
  49. pathname = data[0]
  50. result = self.decoder(data)
  51. yield (pathname, result[pathname])
  52. def __len__(self) -> int:
  53. if isinstance(self.datapipe, Sized):
  54. return len(self.datapipe)
  55. raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))