__init__.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
# TODO(VitalyFedyunin): Rearranging these imports leads to a crash;
# need to clean up dependencies and fix it.
  3. from torch.utils.data.sampler import (
  4. BatchSampler,
  5. RandomSampler,
  6. Sampler,
  7. SequentialSampler,
  8. SubsetRandomSampler,
  9. WeightedRandomSampler,
  10. )
  11. from torch.utils.data.dataset import (
  12. ChainDataset,
  13. ConcatDataset,
  14. Dataset,
  15. IterableDataset,
  16. Subset,
  17. TensorDataset,
  18. random_split,
  19. )
  20. from torch.utils.data.datapipes.datapipe import (
  21. DFIterDataPipe,
  22. DataChunk,
  23. IterDataPipe,
  24. MapDataPipe,
  25. )
  26. from torch.utils.data.dataloader import (
  27. DataLoader,
  28. _DatasetKind,
  29. get_worker_info,
  30. default_collate,
  31. default_convert,
  32. )
  33. from torch.utils.data.distributed import DistributedSampler
  34. from torch.utils.data.datapipes._decorator import (
  35. argument_validation,
  36. functional_datapipe,
  37. guaranteed_datapipes_determinism,
  38. non_deterministic,
  39. runtime_validation,
  40. runtime_validation_disabled,
  41. )
  # Names exported by this package for ``from <package> import *``.
# Keep this list sorted. ``sorted`` compares strings by code point, so
# capitalised names come first, then underscore-prefixed ones, then
# lowercase ones.
# NOTE(review): ``_DatasetKind`` is underscore-prefixed yet exported here;
# presumably deliberate for internal consumers -- confirm before removing.
__all__ = [
    'BatchSampler',
    'ChainDataset',
    'ConcatDataset',
    'DFIterDataPipe',
    'DataChunk',
    'DataLoader',
    'Dataset',
    'DistributedSampler',
    'IterDataPipe',
    'IterableDataset',
    'MapDataPipe',
    'RandomSampler',
    'Sampler',
    'SequentialSampler',
    'Subset',
    'SubsetRandomSampler',
    'TensorDataset',
    'WeightedRandomSampler',
    '_DatasetKind',
    'argument_validation',
    'default_collate',
    'default_convert',
    'functional_datapipe',
    'get_worker_info',
    'guaranteed_datapipes_determinism',
    'non_deterministic',
    'random_split',
    'runtime_validation',
    'runtime_validation_disabled',
]

# Fail at import time if an entry was added out of order. ``assert`` is
# compiled out under ``python -O``, so this is a development-time check only.
assert sorted(__all__) == __all__