import random

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper

__all__ = [
    "ConcatDataFramesPipe",
    "DataFramesAsTuplesPipe",
    "ExampleAggregateAsDataFrames",
    "FilterDataFramesPipe",
    "PerRowDataFramesPipe",
    "ShuffleDataFramesPipe",
]


@functional_datapipe('_dataframes_as_tuples')
class DataFramesAsTuplesPipe(IterDataPipe):
    """Flattens a stream of dataframes into a stream of per-row records."""

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for df in self.source_datapipe:
            # for record in df.to_records(index=False):
            for record in df_wrapper.iterate(df):
                yield record
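
# A minimal usage sketch, assuming pandas-backed frames; IterableWrapper lives
# outside this module, and the exact record type depends on the active wrapper:
#
#   >>> import pandas as pd
#   >>> from torch.utils.data.datapipes.iter import IterableWrapper
#   >>> dfs = IterableWrapper([pd.DataFrame({'x': [0, 1]}), pd.DataFrame({'x': [2]})])
#   >>> list(DataFramesAsTuplesPipe(dfs))   # one record per row: x=0, x=1, x=2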


@functional_datapipe('_dataframes_per_row', enable_df_api_tracing=True)
class PerRowDataFramesPipe(DFIterDataPipe):
    """Splits each incoming dataframe into one-row dataframe slices."""

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for df in self.source_datapipe:
            # TODO(VitalyFedyunin): Replacing with TorchArrow only API, as we are dropping pandas as followup
            for i in range(len(df)):
                yield df[i:i + 1]
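
# Unlike _dataframes_as_tuples, each yielded item here is still a dataframe
# (a one-row slice). A sketch, reusing `dfs` from the comment above:
#
#   >>> per_row = PerRowDataFramesPipe(dfs)
#   >>> [df_wrapper.get_len(df) for df in per_row]   # [1, 1, 1]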


@functional_datapipe('_dataframes_concat', enable_df_api_tracing=True)
class ConcatDataFramesPipe(DFIterDataPipe):
    """Concatenates every `batch` incoming dataframes into a single dataframe."""

    def __init__(self, source_datapipe, batch=3):
        self.source_datapipe = source_datapipe
        self.n_batch = batch

    def __iter__(self):
        buffer = []
        for df in self.source_datapipe:
            buffer.append(df)
            if len(buffer) == self.n_batch:
                yield df_wrapper.concat(buffer)
                buffer = []
        if len(buffer):
            yield df_wrapper.concat(buffer)
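
# Batching sketch: every `batch` incoming frames are concatenated into one
# output frame, and a final partial buffer is flushed as a smaller frame.
# Reusing the one-row frames from the sketch above:
#
#   >>> rebatched = ConcatDataFramesPipe(PerRowDataFramesPipe(dfs), batch=2)
#   >>> [df_wrapper.get_len(df) for df in rebatched]   # [2, 1]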


@functional_datapipe('_dataframes_shuffle', enable_df_api_tracing=True)
class ShuffleDataFramesPipe(DFIterDataPipe):
    """Shuffles rows across all incoming dataframes, then re-chunks them."""

    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        size = None
        all_buffer = []
        for df in self.source_datapipe:
            if size is None:
                # The first frame's length becomes the output chunk size.
                size = df_wrapper.get_len(df)
            for i in range(df_wrapper.get_len(df)):
                all_buffer.append(df_wrapper.get_item(df, i))
        random.shuffle(all_buffer)
        buffer = []
        for df in all_buffer:
            buffer.append(df)
            if len(buffer) == size:
                yield df_wrapper.concat(buffer)
                buffer = []
        if len(buffer):
            yield df_wrapper.concat(buffer)
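
# Note: this buffers every row in memory before shuffling, so memory use grows
# with the total number of rows. A sketch, reusing the pipes above; seeding the
# global RNG is illustrative only, for a reproducible run:
#
#   >>> random.seed(0)
#   >>> shuffled = ShuffleDataFramesPipe(PerRowDataFramesPipe(dfs))
#   >>> [df_wrapper.get_len(df) for df in shuffled]   # [1, 1, 1], rows in shuffled order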


@functional_datapipe('_dataframes_filter', enable_df_api_tracing=True)
class FilterDataFramesPipe(DFIterDataPipe):
    """Keeps only rows for which `filter_fn(row)` is truthy, re-chunking the survivors."""

    def __init__(self, source_datapipe, filter_fn):
        self.source_datapipe = source_datapipe
        self.filter_fn = filter_fn

    def __iter__(self):
        size = None
        all_buffer = []
        filter_res = []
        for df in self.source_datapipe:
            if size is None:
                size = len(df.index)
            for i in range(len(df.index)):
                all_buffer.append(df[i:i + 1])
                filter_res.append(self.filter_fn(df.iloc[i]))
        buffer = []
        for df, res in zip(all_buffer, filter_res):
            if res:
                buffer.append(df)
                if len(buffer) == size:
                    yield df_wrapper.concat(buffer)
                    buffer = []
        if len(buffer):
            yield df_wrapper.concat(buffer)
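
# Unlike the pipes above, this one calls pandas `.index` / `.iloc` directly
# rather than going through df_wrapper, so it assumes pandas frames.
# `filter_fn` receives one row (a Series) at a time; reusing `dfs` from above:
#
#   >>> evens = FilterDataFramesPipe(dfs, filter_fn=lambda row: row['x'] % 2 == 0)
#   >>> [df_wrapper.get_len(df) for df in evens]   # [2]: rows x=0 and x=2 survive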


@functional_datapipe('_to_dataframes_pipe', enable_df_api_tracing=True)
class ExampleAggregateAsDataFrames(DFIterDataPipe):
    """Aggregates plain items into dataframes of up to `dataframe_size` rows."""

    def __init__(self, source_datapipe, dataframe_size=10, columns=None):
        self.source_datapipe = source_datapipe
        self.columns = columns
        self.dataframe_size = dataframe_size

    def _as_list(self, item):
        try:
            return list(item)
        except Exception:  # TODO(VitalyFedyunin): Replace with better iterable exception
            # Non-iterable items become single-element rows.
            return [item]

    def __iter__(self):
        aggregate = []
        for item in self.source_datapipe:
            aggregate.append(self._as_list(item))
            if len(aggregate) == self.dataframe_size:
                yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
                aggregate = []
        if len(aggregate) > 0:
            yield df_wrapper.create_dataframe(aggregate, columns=self.columns)
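
# End-to-end sketch: aggregate plain tuples into frames, 10 rows at a time,
# with a smaller final frame. IterableWrapper is outside this module and the
# column names are illustrative:
#
#   >>> from torch.utils.data.datapipes.iter import IterableWrapper
#   >>> items = IterableWrapper([(i, i * i) for i in range(25)])
#   >>> frames = ExampleAggregateAsDataFrames(items, dataframe_size=10, columns=['x', 'x_sq'])
#   >>> [df_wrapper.get_len(df) for df in frames]   # [10, 10, 5]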