collate.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.

`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
"""
  7. import collections
  8. import contextlib
  9. import re
  10. import torch
  11. from typing import Callable, Dict, Optional, Tuple, Type, Union
  12. np_str_obj_array_pattern = re.compile(r'[SaUO]')
  13. def default_convert(data):
  14. r"""
  15. Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`,
  16. `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
  17. If the input is not an NumPy array, it is left unchanged.
  18. This is used as the default function for collation when both `batch_sampler` and
  19. `batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`.
  20. The general input type to output type mapping is similar to that
  21. of :func:`~torch.utils.data.default_collate`. See the description there for more details.
  22. Args:
  23. data: a single data point to be converted
  24. Examples:
  25. >>> # xdoctest: +SKIP
  26. >>> # Example with `int`
  27. >>> default_convert(0)
  28. 0
  29. >>> # Example with NumPy array
  30. >>> default_convert(np.array([0, 1]))
  31. tensor([0, 1])
  32. >>> # Example with NamedTuple
  33. >>> Point = namedtuple('Point', ['x', 'y'])
  34. >>> default_convert(Point(0, 0))
  35. Point(x=0, y=0)
  36. >>> default_convert(Point(np.array(0), np.array(0)))
  37. Point(x=tensor(0), y=tensor(0))
  38. >>> # Example with List
  39. >>> default_convert([np.array([0, 1]), np.array([2, 3])])
  40. [tensor([0, 1]), tensor([2, 3])]
  41. """
  42. elem_type = type(data)
  43. if isinstance(data, torch.Tensor):
  44. return data
  45. elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
  46. and elem_type.__name__ != 'string_':
  47. # array of string classes and object
  48. if elem_type.__name__ == 'ndarray' \
  49. and np_str_obj_array_pattern.search(data.dtype.str) is not None:
  50. return data
  51. return torch.as_tensor(data)
  52. elif isinstance(data, collections.abc.Mapping):
  53. try:
  54. return elem_type({key: default_convert(data[key]) for key in data})
  55. except TypeError:
  56. # The mapping type may not support `__init__(iterable)`.
  57. return {key: default_convert(data[key]) for key in data}
  58. elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
  59. return elem_type(*(default_convert(d) for d in data))
  60. elif isinstance(data, tuple):
  61. return [default_convert(d) for d in data] # Backwards compatibility.
  62. elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
  63. try:
  64. return elem_type([default_convert(d) for d in data])
  65. except TypeError:
  66. # The sequence type may not support `__init__(iterable)` (e.g., `range`).
  67. return [default_convert(d) for d in data]
  68. else:
  69. return data
  70. default_collate_err_msg_format = (
  71. "default_collate: batch must contain tensors, numpy arrays, numbers, "
  72. "dicts or lists; found {}")
  73. def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  74. r"""
  75. General collate function that handles collection type of element within each batch
  76. and opens function registry to deal with specific element types. `default_collate_fn_map`
  77. provides default collate functions for tensors, numpy arrays, numbers and strings.
  78. Args:
  79. batch: a single batch to be collated
  80. collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
  81. If the element type isn't present in this dictionary,
  82. this function will go through each key of the dictionary in the insertion order to
  83. invoke the corresponding collate function if the element type is a subclass of the key.
  84. Examples:
  85. >>> # Extend this function to handle batch of tensors
  86. >>> def collate_tensor_fn(batch, *, collate_fn_map):
  87. ... return torch.stack(batch, 0)
  88. >>> def custom_collate(batch):
  89. ... collate_map = {torch.Tensor: collate_tensor_fn}
  90. ... return collate(batch, collate_fn_map=collate_map)
  91. >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
  92. >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
  93. Note:
  94. Each collate function requires a positional argument for batch and a keyword argument
  95. for the dictionary of collate functions as `collate_fn_map`.
  96. """
  97. elem = batch[0]
  98. elem_type = type(elem)
  99. if collate_fn_map is not None:
  100. if elem_type in collate_fn_map:
  101. return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  102. for collate_type in collate_fn_map:
  103. if isinstance(elem, collate_type):
  104. return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
  105. if isinstance(elem, collections.abc.Mapping):
  106. try:
  107. return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  108. except TypeError:
  109. # The mapping type may not support `__init__(iterable)`.
  110. return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
  111. elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
  112. return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
  113. elif isinstance(elem, collections.abc.Sequence):
  114. # check to make sure that the elements in batch have consistent size
  115. it = iter(batch)
  116. elem_size = len(next(it))
  117. if not all(len(elem) == elem_size for elem in it):
  118. raise RuntimeError('each element in list of batch should be of equal size')
  119. transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
  120. if isinstance(elem, tuple):
  121. return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] # Backwards compatibility.
  122. else:
  123. try:
  124. return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
  125. except TypeError:
  126. # The sequence type may not support `__init__(iterable)` (e.g., `range`).
  127. return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
  128. raise TypeError(default_collate_err_msg_format.format(elem_type))
  129. def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  130. elem = batch[0]
  131. out = None
  132. if torch.utils.data.get_worker_info() is not None:
  133. # If we're in a background process, concatenate directly into a
  134. # shared memory tensor to avoid an extra copy
  135. numel = sum(x.numel() for x in batch)
  136. storage = elem._typed_storage()._new_shared(numel, device=elem.device)
  137. out = elem.new(storage).resize_(len(batch), *list(elem.size()))
  138. return torch.stack(batch, 0, out=out)
  139. def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  140. elem = batch[0]
  141. # array of string classes and object
  142. if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  143. raise TypeError(default_collate_err_msg_format.format(elem.dtype))
  144. return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
  145. def collate_numpy_scalar_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  146. return torch.as_tensor(batch)
  147. def collate_float_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  148. return torch.tensor(batch, dtype=torch.float64)
  149. def collate_int_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  150. return torch.tensor(batch)
  151. def collate_str_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  152. return batch
  153. default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {torch.Tensor: collate_tensor_fn}
  154. with contextlib.suppress(ImportError):
  155. import numpy as np
  156. # For both ndarray and memmap (subclass of ndarray)
  157. default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
  158. # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
  159. # Skip string scalars
  160. default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
  161. default_collate_fn_map[float] = collate_float_fn
  162. default_collate_fn_map[int] = collate_int_fn
  163. default_collate_fn_map[str] = collate_str_fn
  164. def default_collate(batch):
  165. r"""
  166. Function that takes in a batch of data and puts the elements within the batch
  167. into a tensor with an additional outer dimension - batch size. The exact output type can be
  168. a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
  169. Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
  170. This is used as the default function for collation when
  171. `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
  172. Here is the general input type (based on the type of the element within the batch) to output type mapping:
  173. * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
  174. * NumPy Arrays -> :class:`torch.Tensor`
  175. * `float` -> :class:`torch.Tensor`
  176. * `int` -> :class:`torch.Tensor`
  177. * `str` -> `str` (unchanged)
  178. * `bytes` -> `bytes` (unchanged)
  179. * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
  180. * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
  181. default_collate([V2_1, V2_2, ...]), ...]`
  182. * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
  183. default_collate([V2_1, V2_2, ...]), ...]`
  184. Args:
  185. batch: a single batch to be collated
  186. Examples:
  187. >>> # xdoctest: +SKIP
  188. >>> # Example with a batch of `int`s:
  189. >>> default_collate([0, 1, 2, 3])
  190. tensor([0, 1, 2, 3])
  191. >>> # Example with a batch of `str`s:
  192. >>> default_collate(['a', 'b', 'c'])
  193. ['a', 'b', 'c']
  194. >>> # Example with `Map` inside the batch:
  195. >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
  196. {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
  197. >>> # Example with `NamedTuple` inside the batch:
  198. >>> Point = namedtuple('Point', ['x', 'y'])
  199. >>> default_collate([Point(0, 0), Point(1, 1)])
  200. Point(x=tensor([0, 1]), y=tensor([0, 1]))
  201. >>> # Example with `Tuple` inside the batch:
  202. >>> default_collate([(0, 1), (2, 3)])
  203. [tensor([0, 2]), tensor([1, 3])]
  204. >>> # Example with `List` inside the batch:
  205. >>> default_collate([[0, 1], [2, 3]])
  206. [tensor([0, 2]), tensor([1, 3])]
  207. >>> # Two options to extend `default_collate` to handle specific type
  208. >>> # Option 1: Write custom collate function and invoke `default_collate`
  209. >>> def custom_collate(batch):
  210. ... elem = batch[0]
  211. ... if isinstance(elem, CustomType): # Some custom condition
  212. ... return ...
  213. ... else: # Fall back to `default_collate`
  214. ... return default_collate(batch)
  215. >>> # Option 2: In-place modify `default_collate_fn_map`
  216. >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
  217. ... return ...
  218. >>> default_collate_fn_map.update(CustoType, collate_customtype_fn)
  219. >>> default_collate(batch) # Handle `CustomType` automatically
  220. """
  221. return collate(batch, collate_fn_map=default_collate_fn_map)