1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- from io import BufferedIOBase
- from typing import Any, Callable, Iterable, Iterator, Sized, Tuple
- from torch.utils.data.datapipes._decorator import functional_datapipe
- from torch.utils.data.datapipes.datapipe import IterDataPipe
- from torch.utils.data.datapipes.utils.common import _deprecation_warning
- from torch.utils.data.datapipes.utils.decoder import (
- Decoder,
- basichandlers as decoder_basichandlers,
- imagehandler as decoder_imagehandler,
- extension_extract_fn
- )
- __all__ = ["RoutedDecoderIterDataPipe", ]
- @functional_datapipe('routed_decode')
- class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]):
- r"""
- Decodes binary streams from input DataPipe, yields pathname and decoded data
- in a tuple (functional name: ``routed_decode``).
- Args:
- datapipe: Iterable datapipe that provides pathname and binary stream in tuples
- handlers: Optional user defined decoder handlers. If ``None``, basic and image decoder
- handlers will be set as default. If multiple handles are provided, the priority
- order follows the order of handlers (the first handler has the top priority)
- key_fn: Function for decoder to extract key from pathname to dispatch handlers.
- Default is set to extract file extension from pathname
- Note:
- When ``key_fn`` is specified returning anything other than extension, the default
- handler will not work and users need to specify custom handler. Custom handler
- could use regex to determine the eligibility to handle data.
- """
- def __init__(self,
- datapipe: Iterable[Tuple[str, BufferedIOBase]],
- *handlers: Callable,
- key_fn: Callable = extension_extract_fn) -> None:
- super().__init__()
- self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
- if not handlers:
- handlers = (decoder_basichandlers, decoder_imagehandler('torch'))
- self.decoder = Decoder(*handlers, key_fn=key_fn)
- _deprecation_warning(
- type(self).__name__,
- deprecation_version="1.12",
- removal_version="1.13",
- old_functional_name="routed_decode",
- )
- def add_handler(self, *handler: Callable) -> None:
- self.decoder.add_handler(*handler)
- def __iter__(self) -> Iterator[Tuple[str, Any]]:
- for data in self.datapipe:
- pathname = data[0]
- result = self.decoder(data)
- yield (pathname, result[pathname])
- def __len__(self) -> int:
- if isinstance(self.datapipe, Sized):
- return len(self.datapipe)
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
|