__init__.py 645 B

1234567891011121314
  1. from .parallel_apply import parallel_apply
  2. from .replicate import replicate
  3. from .data_parallel import DataParallel, data_parallel
  4. from .scatter_gather import scatter, gather
  5. from .distributed import DistributedDataParallel
  6. __all__ = ['replicate', 'scatter', 'parallel_apply', 'gather', 'data_parallel',
  7. 'DataParallel', 'DistributedDataParallel']
  8. def DistributedDataParallelCPU(*args, **kwargs):
  9. import warnings
  10. warnings.warn("torch.nn.parallel.DistributedDataParallelCPU is deprecated, "
  11. "please use torch.nn.parallel.DistributedDataParallel instead.")
  12. return DistributedDataParallel(*args, **kwargs)