123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- from typing import Any, Dict, List
- from torch.utils.data.datapipes._decorator import functional_datapipe
- from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
- from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
- # TODO(VitalyFedyunin): Add error when two different traces get combined
- __all__ = [
- "Capture",
- "CaptureA",
- "CaptureAdd",
- "CaptureCall",
- "CaptureControl",
- "CaptureDataFrame",
- "CaptureDataFrameWithDataPipeOps",
- "CaptureF",
- "CaptureGetAttr",
- "CaptureGetItem",
- "CaptureInitial",
- "CaptureLikeMock",
- "CaptureMul",
- "CaptureSetItem",
- "CaptureSub",
- "CaptureVariable",
- "CaptureVariableAssign",
- "DataFrameTracer",
- "DataFrameTracedOps",
- "disable_capture",
- "get_val",
- ]
def disable_capture():
    """Globally switch off trace capturing (new CaptureVariables will raise)."""
    CaptureControl.disabled = True
class CaptureControl():
    # Global switch consulted by CaptureVariable.__init__: when True,
    # creating a new capture variable raises instead of recording.
    disabled = False
class DataFrameTracedOps(DFIterDataPipe):
    """Datapipe that replays a recorded trace over every incoming dataframe.

    ``output_var`` is the final CaptureVariable of a trace; each item from
    ``source_datapipe`` is fed through ``output_var.apply_ops`` and yielded.
    """

    def __init__(self, source_datapipe, output_var):
        self.source_datapipe = source_datapipe
        self.output_var = output_var

    def __iter__(self):
        for dataframe in self.source_datapipe:
            yield self.output_var.apply_ops(dataframe)
# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
# Operation names that should be forwarded to the materialized datapipe
# (via as_datapipe) instead of being lazily captured on the trace.
DATAPIPES_OPS = ['_dataframes_as_tuples', 'groupby', '_dataframes_filter', 'map', 'to_datapipe',
                 'shuffle', 'concat', 'batch', '_dataframes_per_row', '_dataframes_concat', '_dataframes_shuffle']

# Attribute names that must raise AttributeError rather than be captured lazily.
UNIMPLEMENTED_ATTR = ['__deepcopy__', '__setstate__', 'is_shardable', 'apply_sharding']
class Capture:
    """Base node of a lazily-recorded dataframe expression trace.

    Operators (``+``, ``-``, ``*``), item access, attribute access and calls
    on a Capture build further Capture nodes instead of executing; statements
    are appended to the shared ``ctx`` so the whole expression can later be
    replayed against real dataframes.
    """
    # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures

    def __init__(self, schema_df=None):
        # ctx is shared by reference across every node of one trace:
        #   'operations' - recorded statements, replayed in order
        #   'variables'  - every CaptureVariable created for this trace
        #   'schema_df'  - sample dataframe used to resolve `.columns` lazily
        self.ctx = {'operations': [], 'variables': [], 'schema_df': schema_df}

    def __str__(self):
        return self._ops_str()

    def _ops_str(self):
        # Render one recorded operation per line, in execution order.
        res = ""
        for op in self.ctx['operations']:
            if len(res) > 0:
                res += "\n"
            res += str(op)
        return res

    def __getstate__(self):
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        # Drop cached/unpicklable state before serializing; values are
        # recomputed when the trace is replayed.
        self.ctx['schema_df'] = None
        for var in self.ctx['variables']:
            var.calculated_value = None
        state = {}
        for item in self.__dict__:
            state[item] = getattr(self, item)
        return state

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getattr__(self, attrname):
        if attrname == 'kwarg' or attrname == 'kwargs':
            raise Exception('no kwargs!')
        if attrname in ['__deepcopy__']:
            # copy.deepcopy probes for this; it must see AttributeError,
            # not a captured node.
            raise AttributeError()
        result = CaptureGetAttr(self, attrname, ctx=self.ctx)
        return result

    def __getitem__(self, key):
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        # Item assignment is a statement, so it is recorded immediately.
        self.ctx['operations'].append(
            CaptureSetItem(self, key, value, ctx=self.ctx))

    def __add__(self, add_val):
        # Record `var = self + add_val` and return the new variable node.
        res = CaptureAdd(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx['operations'].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
        return var

    def __sub__(self, add_val):
        res = CaptureSub(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx['operations'].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx))
        return var

    def __mul__(self, add_val):
        res = CaptureMul(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        self.ctx['operations'].append(t)
        return var

    def _is_context_empty(self):
        return len(self.ctx['operations']) == 0 and len(self.ctx['variables']) == 0

    def apply_ops_2(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        # Seed the trace's input variable with a concrete dataframe, then replay.
        self.ctx['variables'][0].calculated_value = dataframe
        for op in self.ctx['operations']:
            op.execute()

    @property
    def columns(self):
        # Resolve column names by replaying the trace on the sample schema_df.
        self.apply_ops_2(self.ctx['schema_df'])
        value = self.execute()
        return value.columns

    # TODO(VitalyFedyunin): Add tests
    # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture

    def __call__(self, *args, **kwargs):
        # TODO: Check if args or kwargs have more than one different context
        # If this node has no context yet, adopt the first non-empty one
        # found among the call's arguments.
        if self._is_context_empty():
            # TODO: Allow CaptureA to take context from mock
            for arg in args:
                if isinstance(arg, Capture) and not arg._is_context_empty():
                    self.ctx = arg.ctx
                    break
            if self._is_context_empty():
                for k, v in kwargs.items():
                    if isinstance(k, Capture) and not k._is_context_empty():
                        self.ctx = k.ctx
                        break
                    if isinstance(v, Capture) and not v._is_context_empty():
                        self.ctx = v.ctx
                        break
        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        self.ctx['operations'].append(t)
        return var
class CaptureF(Capture):
    """Base for capture nodes configured entirely through keyword arguments.

    Fix: a fresh context created here previously lacked the 'schema_df' key
    that ``Capture.__init__`` provides, so ``Capture.columns`` raised
    KeyError for traces rooted in a CaptureF.  The key is now present
    (as None) to keep the context shape consistent.
    """

    def __init__(self, ctx=None, **kwargs):
        if ctx is None:
            # Same context shape as Capture.__init__.
            self.ctx = {'operations': [], 'variables': [], 'schema_df': None}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
class CaptureA(CaptureF):
    """Stand-in for a real attribute: prints its name, executes to the saved value."""

    def __str__(self):
        return '{}'.format(self.kwargs['name'])

    def execute(self):
        # Return the real attribute that CaptureLikeMock stashed away.
        return self.kwargs['real_attribute']
class CaptureLikeMock():
    """Context manager that patches a dotted ``name`` with a CaptureA node.

    On enter the real attribute is saved and replaced by a CaptureA that
    remembers it; on exit the original attribute is restored.
    """

    def __init__(self, name):
        import unittest.mock as mock
        # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
        get_target, attribute = mock._get_target(name)  # type: ignore[attr-defined]
        self.get_target = get_target
        self.attribute = attribute
        self.name = name

    def __enter__(self):
        target = self.get_target()
        self.save = getattr(target, self.attribute)
        setattr(target, self.attribute, CaptureA(name=self.name, real_attribute=self.save))

    def __exit__(self, *exc_info):
        setattr(self.get_target(), self.attribute, self.save)
class CaptureCall(Capture):
    """Records a call expression ``callable(*args, **kwargs)`` for deferred execution.

    Stored kwargs: 'args' (positional arguments, possibly Capture nodes) and
    'kwargs' (keyword arguments passed through verbatim).

    Fix: a fresh context created here previously lacked the 'schema_df' key
    that ``Capture.__init__`` provides, so ``Capture.columns`` raised
    KeyError for traces rooted here; the key is now present (as None).
    """

    def __init__(self, callable, ctx=None, **kwargs):
        if ctx is None:
            # Same context shape as Capture.__init__.
            self.ctx = {'operations': [], 'variables': [], 'schema_df': None}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
        self.callable = callable

    def __str__(self):
        return "{callable}({args},{kwargs})".format(callable=self.callable, **self.kwargs)

    def execute(self):
        # TODO: VitalyFedyunin execute kwargs and maybe nested structures
        # Positional args that are Capture nodes are executed; plain values pass through.
        executed_args = []
        for arg in self.kwargs['args']:
            executed_args.append(arg.execute() if isinstance(arg, Capture) else arg)
        left = get_val(self.callable)
        return left(*executed_args, **self.kwargs['kwargs'])
class CaptureVariableAssign(CaptureF):
    """Statement ``variable = value``; executing stores the value's result on the variable."""

    def __str__(self):
        return "{} = {}".format(self.kwargs['variable'], self.kwargs['value'])

    def execute(self):
        target = self.kwargs['variable']
        target.calculated_value = self.kwargs['value'].execute()
class CaptureVariable(Capture):
    """A named slot in the trace; holds the concrete value once ops are executed."""
    # TODO(VitalyFedyunin): This should be atomic and thread safe
    # Process-wide counter used to generate unique variable names.
    names_idx = 0

    def __init__(self, value, ctx):
        if CaptureControl.disabled:
            raise Exception('Attempting to create capture variable with capture off')
        self.ctx = ctx
        self.value = value
        self.name = 'var_%s' % CaptureVariable.names_idx
        CaptureVariable.names_idx += 1
        self.ctx['variables'].append(self)

    def __str__(self):
        return self.name

    def execute(self):
        # calculated_value is assigned by CaptureVariableAssign.execute / apply_ops.
        return self.calculated_value

    def apply_ops(self, dataframe):
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        # Seed the trace's input variable with the concrete dataframe, replay
        # every recorded op, then return this variable's resulting value.
        self.ctx['variables'][0].calculated_value = dataframe
        for op in self.ctx['operations']:
            op.execute()
        return self.calculated_value
class CaptureGetItem(Capture):
    """Expression node for ``left[key]``."""

    def __init__(self, left, key, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key

    def __str__(self):
        return "{}[{}]".format(self.left, get_val(self.key))

    def execute(self):
        container = self.left.execute()
        return container[self.key]
class CaptureSetItem(Capture):
    """Statement node for ``left[key] = value``; executed for its side effect."""

    def __init__(self, left, key, value, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key
        self.value = value

    def __str__(self):
        return "{}[{}] = {}".format(self.left, get_val(self.key), self.value)

    def execute(self):
        # Evaluate the container first, then the value (matches recorded order).
        container = self.left.execute()
        container[self.key] = self.value.execute()
class CaptureAdd(Capture):
    """Expression node for ``left + right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return "{} + {}".format(self.left, self.right)

    def execute(self):
        return get_val(self.left) + get_val(self.right)
class CaptureMul(Capture):
    """Expression node for ``left * right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return "{} * {}".format(self.left, self.right)

    def execute(self):
        return get_val(self.left) * get_val(self.right)
class CaptureSub(Capture):
    """Expression node for ``left - right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return "{} - {}".format(self.left, self.right)

    def execute(self):
        return get_val(self.left) - get_val(self.right)
class CaptureGetAttr(Capture):
    """Expression node for ``src.name``."""

    def __init__(self, src, name, ctx):
        self.ctx = ctx
        self.src = src
        self.name = name

    def __str__(self):
        return "{}.{}".format(self.src, self.name)

    def execute(self):
        return getattr(get_val(self.src), self.name)
def get_val(capture):
    """Resolve ``capture`` to a concrete value.

    Capture nodes are executed; plain strings are deliberately re-quoted
    (they appear inside rendered expressions); anything else passes through.
    """
    if isinstance(capture, Capture):
        return capture.execute()
    if isinstance(capture, str):
        return '"%s"' % capture
    return capture
class CaptureInitial(CaptureVariable):
    """Root variable of a fresh trace; starts its own empty shared context."""

    def __init__(self, schema_df=None):
        fresh_ctx: Dict[str, List[Any]] = {'operations': [], 'variables': [], 'schema_df': schema_df}
        super().__init__(None, fresh_ctx)
        # Rename from 'var_N' to 'input_var_N' to mark the trace input.
        self.name = 'input_%s' % self.name
class CaptureDataFrame(CaptureInitial):
    # Alias of CaptureInitial; exists to give dataframe traces a distinct type.
    pass
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    """CaptureDataFrame that also exposes the datapipe operation surface.

    Methods here either materialize the recorded trace as a real datapipe
    (``as_datapipe``) or route specific ops onto the traced stream.
    """

    def as_datapipe(self):
        # The trace's first variable is the input; it carries the source datapipe.
        return DataFrameTracedOps(
            self.ctx['variables'][0].source_datapipe, self)

    def raw_iterator(self):
        return self.as_datapipe().__iter__()

    def __iter__(self):
        # Default iteration yields per-row tuples rather than dataframe chunks.
        return iter(self._dataframes_as_tuples())

    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
        # Re-chunk rows into dataframes of `batch_size`, then wrap each
        # chunk as a batch of one via the regular datapipe batch op.
        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
        dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
        dp._dp_contains_dataframe = True
        return dp

    def groupby(self,
                group_key_fn,
                *,
                buffer_size=10000,
                group_size=None,
                guaranteed_group_size=None,
                drop_remaining=False):
        # Group per-row dataframes using the regular datapipe groupby.
        dp = self._dataframes_per_row()
        dp = dp.as_datapipe().groupby(group_key_fn, buffer_size=buffer_size, group_size=group_size,
                                      guaranteed_group_size=guaranteed_group_size, drop_remaining=drop_remaining)
        return dp

    def shuffle(self, *args, **kwargs):
        return self._dataframes_shuffle(*args, **kwargs)

    def filter(self, *args, **kwargs):
        return self._dataframes_filter(*args, **kwargs)

    def collate(self, *args, **kwargs):
        raise Exception("Can't collate unbatched DataFrames stream")

    def __getattr__(self, attrname):  # ?
        # Known datapipe ops are dispatched on the materialized datapipe;
        # everything else falls back to lazy attribute capture on the trace.
        if attrname in UNIMPLEMENTED_ATTR:
            raise AttributeError('Attempting to get ', attrname)
        if attrname in DATAPIPES_OPS:
            return (self.as_datapipe()).__getattr__(attrname)
        return super().__getattr__(attrname)
@functional_datapipe('trace_as_dataframe')
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):
    """Entry point: wraps a datapipe of dataframes and starts recording ops on it."""
    # Source datapipe whose dataframes flow through the recorded trace.
    source_datapipe = None

    # TODO(VitalyFedyunin): Must implement all special functions of datapipes

    def set_shuffle_settings(self, *args, **kwargs):
        # Shuffle settings are meaningless for a trace; accept and ignore them.
        pass

    def is_shardable(self):
        return False

    def __init__(self, source_datapipe, schema_df=None):
        self.source_datapipe = source_datapipe
        # Without an explicit schema sample, peek one dataframe from the source.
        if schema_df is None:
            schema_df = next(iter(self.source_datapipe))
        super().__init__(schema_df=schema_df)
|