
from typing import Any, Dict, List

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
from torch.utils.data.datapipes.dataframe.structures import DataChunkDF

# TODO(VitalyFedyunin): Add error when two different traces get combined
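
# This module lazily captures operations performed on DataFrames flowing through a
# DataPipe: instead of executing them immediately, each operation is recorded as a
# graph of `Capture` nodes and replayed against every DataFrame that streams through
# the pipe (see `DataFrameTracedOps` and `CaptureVariable.apply_ops` below).
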
__all__ = [
    "Capture",
    "CaptureA",
    "CaptureAdd",
    "CaptureCall",
    "CaptureControl",
    "CaptureDataFrame",
    "CaptureDataFrameWithDataPipeOps",
    "CaptureF",
    "CaptureGetAttr",
    "CaptureGetItem",
    "CaptureInitial",
    "CaptureLikeMock",
    "CaptureMul",
    "CaptureSetItem",
    "CaptureSub",
    "CaptureVariable",
    "CaptureVariableAssign",
    "DataFrameTracer",
    "DataFrameTracedOps",
    "disable_capture",
    "get_val",
]


def disable_capture():
    CaptureControl.disabled = True


class CaptureControl():
    disabled = False


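# Applies the captured operation graph (`output_var`) to every DataFrame produced by
# `source_datapipe`.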
class DataFrameTracedOps(DFIterDataPipe):
    def __init__(self, source_datapipe, output_var):
        self.source_datapipe = source_datapipe
        self.output_var = output_var

    def __iter__(self):
        for item in self.source_datapipe:
            yield self.output_var.apply_ops(item)


# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
DATAPIPES_OPS = ['_dataframes_as_tuples', 'groupby', '_dataframes_filter', 'map', 'to_datapipe',
                 'shuffle', 'concat', 'batch', '_dataframes_per_row', '_dataframes_concat', '_dataframes_shuffle']

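# Attribute lookups that should raise AttributeError instead of being traced
# (see CaptureDataFrameWithDataPipeOps.__getattr__ below).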
UNIMPLEMENTED_ATTR = ['__deepcopy__', '__setstate__', 'is_shardable', 'apply_sharding']


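# Base class of the traced-expression graph. The dunder methods (__getattr__,
# __getitem__, the arithmetic operators, __call__) do not compute anything; they
# append Capture nodes to the shared `ctx` so the expression can be replayed later
# via execute()/apply_ops.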
class Capture:
    # TODO: All operations are shared across the entire InitialCapture; need to figure out
    # what happens if we join two captures.

    def __init__(self, schema_df=None):
        self.ctx = {'operations': [], 'variables': [], 'schema_df': schema_df}

    def __str__(self):
        return self._ops_str()

    def _ops_str(self):
        res = ""
        for op in self.ctx['operations']:
            if len(res) > 0:
                res += "\n"
            res += str(op)
        return res

    def __getstate__(self):
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        self.ctx['schema_df'] = None
        for var in self.ctx['variables']:
            var.calculated_value = None
        state = {}
        for item in self.__dict__:
            state[item] = getattr(self, item)
        return state

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getattr__(self, attrname):
        if attrname == 'kwarg' or attrname == 'kwargs':
            raise Exception('no kwargs!')
        if attrname in ['__deepcopy__']:
            raise AttributeError()
        result = CaptureGetAttr(self, attrname, ctx=self.ctx)
        return result

    def __getitem__(self, key):
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        self.ctx['operations'].append(
            CaptureSetItem(self, key, value, ctx=self.ctx))

    def __add__(self, add_val):
        res = CaptureAdd(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx['operations'].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
        return var

    def __sub__(self, sub_val):
        res = CaptureSub(self, sub_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx['operations'].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
        return var

    def __mul__(self, mul_val):
        res = CaptureMul(self, mul_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        self.ctx['operations'].append(t)
        return var

    def _is_context_empty(self):
        return len(self.ctx['operations']) == 0 and len(self.ctx['variables']) == 0

    def apply_ops_2(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx['variables'][0].calculated_value = dataframe
        for op in self.ctx['operations']:
            op.execute()

    @property
    def columns(self):
        self.apply_ops_2(self.ctx['schema_df'])
        value = self.execute()
        return value.columns

    # TODO(VitalyFedyunin): Add tests
    # TODO(VitalyFedyunin): Need to join contexts if one of them is empty because we used capture
    def __call__(self, *args, **kwargs):
        # TODO: Check if args or kwargs have more than one different context
        if self._is_context_empty():
            # TODO: Allow CaptureA to take context from mock
            for arg in args:
                if isinstance(arg, Capture) and not arg._is_context_empty():
                    self.ctx = arg.ctx
                    break
            if self._is_context_empty():
                for k, v in kwargs.items():
                    if isinstance(k, Capture) and not k._is_context_empty():
                        self.ctx = k.ctx
                        break
                    if isinstance(v, Capture) and not v._is_context_empty():
                        self.ctx = v.ctx
                        break

        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        self.ctx['operations'].append(t)
        return var


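# Small helper base that stores its constructor kwargs; used by CaptureA and
# CaptureVariableAssign below.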
class CaptureF(Capture):
    def __init__(self, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {'operations': [], 'variables': []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs


class CaptureA(CaptureF):
    def __str__(self):
        return '{name}'.format(name=self.kwargs['name'])

    def execute(self):
        value = self.kwargs['real_attribute']
        return value


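# Context manager that temporarily replaces a target attribute (addressed with a
# `unittest.mock.patch`-style dotted name) with a CaptureA node so that uses of it
# are traced; the original attribute is restored on exit.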
class CaptureLikeMock():
    def __init__(self, name):
        import unittest.mock as mock
        # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
        get_target, attribute = mock._get_target(name)  # type: ignore[attr-defined]
        self.get_target = get_target
        self.attribute = attribute
        self.name = name

    def __enter__(self):
        self.save = getattr(self.get_target(), self.attribute)
        capt = CaptureA(name=self.name, real_attribute=self.save)
        setattr(self.get_target(), self.attribute, capt)

    def __exit__(self, *exc_info):
        setattr(self.get_target(), self.attribute, self.save)


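# Records a call on a captured object; execute() resolves the callable and any
# captured positional arguments, then performs the real call.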
class CaptureCall(Capture):
    def __init__(self, callable, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {'operations': [], 'variables': []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
        self.callable = callable

    def __str__(self):
        return "{callable}({args},{kwargs})".format(callable=self.callable, **self.kwargs)

    def execute(self):
        # TODO(VitalyFedyunin): Execute kwargs and maybe nested structures
        executed_args = []
        for arg in self.kwargs['args']:
            if isinstance(arg, Capture):
                executed_args.append(arg.execute())
            else:
                executed_args.append(arg)
        left = get_val(self.callable)
        return left(*executed_args, **self.kwargs['kwargs'])


class CaptureVariableAssign(CaptureF):
    def __str__(self):
        variable = self.kwargs['variable']
        value = self.kwargs['value']
        return "{variable} = {value}".format(variable=variable, value=value)

    def execute(self):
        self.kwargs['variable'].calculated_value = self.kwargs['value'].execute()


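# Placeholder for an input or intermediate value. execute() returns the value computed
# for it so far; apply_ops() binds the first variable in the context to the incoming
# DataFrame and replays all recorded operations.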
class CaptureVariable(Capture):
    # TODO(VitalyFedyunin): This should be atomic and thread safe
    names_idx = 0

    def __init__(self, value, ctx):
        if CaptureControl.disabled:
            raise Exception('Attempting to create capture variable with capture off')
        self.ctx = ctx
        self.value = value
        self.name = 'var_%s' % CaptureVariable.names_idx
        CaptureVariable.names_idx += 1
        self.ctx['variables'].append(self)

    def __str__(self):
        return self.name

    def execute(self):
        return self.calculated_value

    def apply_ops(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx['variables'][0].calculated_value = dataframe
        for op in self.ctx['operations']:
            op.execute()
        return self.calculated_value


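# The classes below are the elementary trace nodes: each stores its operands, renders
# itself in __str__ for debugging, and performs the real operation in execute().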
class CaptureGetItem(Capture):
    def __init__(self, left, key, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key

    def __str__(self):
        return "%s[%s]" % (self.left, get_val(self.key))

    def execute(self):
        left = self.left.execute()
        return left[self.key]


class CaptureSetItem(Capture):
    def __init__(self, left, key, value, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key
        self.value = value

    def __str__(self):
        return "%s[%s] = %s" % (self.left, get_val(self.key), self.value)

    def execute(self):
        left = self.left.execute()
        value = self.value.execute()
        left[self.key] = value


class CaptureAdd(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return "%s + %s" % (self.left, self.right)

    def execute(self):
        return get_val(self.left) + get_val(self.right)


class CaptureMul(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return "%s * %s" % (self.left, self.right)

    def execute(self):
        return get_val(self.left) * get_val(self.right)


class CaptureSub(Capture):
    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return "%s - %s" % (self.left, self.right)

    def execute(self):
        return get_val(self.left) - get_val(self.right)


class CaptureGetAttr(Capture):
    def __init__(self, src, name, ctx):
        self.ctx = ctx
        self.src = src
        self.name = name

    def __str__(self):
        return "%s.%s" % (self.src, self.name)

    def execute(self):
        val = get_val(self.src)
        return getattr(val, self.name)


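# Resolves a value at execution time: Capture nodes are executed, plain strings are
# wrapped in quotes, and anything else is returned unchanged.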
def get_val(capture):
    if isinstance(capture, Capture):
        return capture.execute()
    elif isinstance(capture, str):
        return '"%s"' % capture
    else:
        return capture


class CaptureInitial(CaptureVariable):
    def __init__(self, schema_df=None):
        new_ctx: Dict[str, List[Any]] = {'operations': [], 'variables': [], 'schema_df': schema_df}
        super().__init__(None, new_ctx)
        self.name = 'input_%s' % self.name


class CaptureDataFrame(CaptureInitial):
    pass


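# Adds DataPipe-flavoured operations (batch, groupby, shuffle, filter, iteration) on
# top of the traced DataFrame, delegating to the `_dataframes_*` operations and to the
# underlying datapipe obtained via as_datapipe().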
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    def as_datapipe(self):
        return DataFrameTracedOps(
            self.ctx['variables'][0].source_datapipe, self)

    def raw_iterator(self):
        return self.as_datapipe().__iter__()

    def __iter__(self):
        return iter(self._dataframes_as_tuples())

    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
        dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
        dp._dp_contains_dataframe = True
        return dp

    def groupby(self,
                group_key_fn,
                *,
                buffer_size=10000,
                group_size=None,
                guaranteed_group_size=None,
                drop_remaining=False):
        dp = self._dataframes_per_row()
        dp = dp.as_datapipe().groupby(group_key_fn, buffer_size=buffer_size, group_size=group_size,
                                      guaranteed_group_size=guaranteed_group_size, drop_remaining=drop_remaining)
        return dp

    def shuffle(self, *args, **kwargs):
        return self._dataframes_shuffle(*args, **kwargs)

    def filter(self, *args, **kwargs):
        return self._dataframes_filter(*args, **kwargs)

    def collate(self, *args, **kwargs):
        raise Exception("Can't collate unbatched DataFrames stream")

    def __getattr__(self, attrname):  # ?
        if attrname in UNIMPLEMENTED_ATTR:
            raise AttributeError('Attempting to get ', attrname)
        if attrname in DATAPIPES_OPS:
            return (self.as_datapipe()).__getattr__(attrname)
        return super().__getattr__(attrname)


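# Entry point of the tracer: registered as the functional datapipe method
# `trace_as_dataframe()`. The first element of the source pipe (or an explicit
# `schema_df`) serves as the schema DataFrame, e.g. for resolving `columns`.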
@functional_datapipe('trace_as_dataframe')
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):
    source_datapipe = None

    # TODO(VitalyFedyunin): Must implement all special functions of datapipes
    def set_shuffle_settings(self, *args, **kwargs):
        pass

    def is_shardable(self):
        return False

    def __init__(self, source_datapipe, schema_df=None):
        self.source_datapipe = source_datapipe
        if schema_df is None:
            schema_df = next(iter(self.source_datapipe))
        super().__init__(schema_df=schema_df)
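

# Illustrative usage sketch (assumes `source_dp` is an IterDataPipe yielding
# DataFrame-like chunks with columns 'x' and 'y'):
#
#     df = source_dp.trace_as_dataframe()  # DataFrameTracer, registered above
#     df['z'] = df['x'] + df['y']          # recorded as Capture ops, nothing executes yet
#     for row in df:                       # ops are replayed per incoming DataFrame chunk
#         pass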