import copy
import warnings

from torch.utils.data.datapipes.datapipe import MapDataPipe

__all__ = ["SequenceWrapperMapDataPipe"]


class SequenceWrapperMapDataPipe(MapDataPipe):
    r"""
    Wraps a sequence object into a MapDataPipe.

    Indexing and ``len()`` are delegated directly to the wrapped object, so any
    object supporting ``__getitem__`` / ``__len__`` (lists, ranges, dicts, ...)
    can be used.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: If ``True`` (default), store a deep copy of ``sequence`` so that
            later in-place changes to the source object are not visible through
            this DataPipe. If the deep copy raises ``TypeError`` (e.g. the sequence
            holds an object that cannot be copied), a warning is emitted and the
            original object is stored instead.

    .. note::
      If ``deepcopy`` is set to False explicitly, users should ensure
      that the data pipeline doesn't contain any in-place operations over
      the sequence, in order to prevent data inconsistency across iterations.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
        >>> dp['a']
        100
    """

    def __init__(self, sequence, deepcopy=True):
        # Default to sharing the caller's object; replaced by a copy below when
        # requested and possible.
        self.sequence = sequence
        if deepcopy:
            try:
                self.sequence = copy.deepcopy(sequence)
            except TypeError:
                # stacklevel=2 attributes the warning to the code constructing
                # the DataPipe rather than to this library line.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data",
                    stacklevel=2,
                )

    def __getitem__(self, index):
        """Return ``self.sequence[index]``; lookup errors propagate unchanged."""
        return self.sequence[index]

    def __len__(self):
        """Return ``len(self.sequence)``."""
        return len(self.sequence)