random.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import torch
  2. from typing import cast, Iterable, List, Union
  3. from . import _lazy_init, _lazy_call, device_count, current_device
  4. from .. import Tensor
  5. __all__ = ['get_rng_state', 'get_rng_state_all',
  6. 'set_rng_state', 'set_rng_state_all',
  7. 'manual_seed', 'manual_seed_all',
  8. 'seed', 'seed_all', 'initial_seed']
  9. def get_rng_state(device: Union[int, str, torch.device] = 'cuda') -> Tensor:
  10. r"""Returns the random number generator state of the specified GPU as a ByteTensor.
  11. Args:
  12. device (torch.device or int, optional): The device to return the RNG state of.
  13. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
  14. .. warning::
  15. This function eagerly initializes CUDA.
  16. """
  17. _lazy_init()
  18. if isinstance(device, str):
  19. device = torch.device(device)
  20. elif isinstance(device, int):
  21. device = torch.device('cuda', device)
  22. idx = device.index
  23. if idx is None:
  24. idx = current_device()
  25. default_generator = torch.cuda.default_generators[idx]
  26. return default_generator.get_state()
  27. def get_rng_state_all() -> List[Tensor]:
  28. r"""Returns a list of ByteTensor representing the random number states of all devices."""
  29. results = []
  30. for i in range(device_count()):
  31. results.append(get_rng_state(i))
  32. return results
  33. def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cuda') -> None:
  34. r"""Sets the random number generator state of the specified GPU.
  35. Args:
  36. new_state (torch.ByteTensor): The desired state
  37. device (torch.device or int, optional): The device to set the RNG state.
  38. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
  39. """
  40. new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
  41. if isinstance(device, str):
  42. device = torch.device(device)
  43. elif isinstance(device, int):
  44. device = torch.device('cuda', device)
  45. def cb():
  46. idx = cast(torch.device, device).index
  47. if idx is None:
  48. idx = current_device()
  49. default_generator = torch.cuda.default_generators[idx]
  50. default_generator.set_state(new_state_copy)
  51. _lazy_call(cb)
  52. def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
  53. r"""Sets the random number generator state of all devices.
  54. Args:
  55. new_states (Iterable of torch.ByteTensor): The desired state for each device"""
  56. for i, state in enumerate(new_states):
  57. set_rng_state(state, i)
  58. def manual_seed(seed: int) -> None:
  59. r"""Sets the seed for generating random numbers for the current GPU.
  60. It's safe to call this function if CUDA is not available; in that
  61. case, it is silently ignored.
  62. Args:
  63. seed (int): The desired seed.
  64. .. warning::
  65. If you are working with a multi-GPU model, this function is insufficient
  66. to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
  67. """
  68. seed = int(seed)
  69. def cb():
  70. idx = current_device()
  71. default_generator = torch.cuda.default_generators[idx]
  72. default_generator.manual_seed(seed)
  73. _lazy_call(cb, seed=True)
  74. def manual_seed_all(seed: int) -> None:
  75. r"""Sets the seed for generating random numbers on all GPUs.
  76. It's safe to call this function if CUDA is not available; in that
  77. case, it is silently ignored.
  78. Args:
  79. seed (int): The desired seed.
  80. """
  81. seed = int(seed)
  82. def cb():
  83. for i in range(device_count()):
  84. default_generator = torch.cuda.default_generators[i]
  85. default_generator.manual_seed(seed)
  86. _lazy_call(cb, seed_all=True)
  87. def seed() -> None:
  88. r"""Sets the seed for generating random numbers to a random number for the current GPU.
  89. It's safe to call this function if CUDA is not available; in that
  90. case, it is silently ignored.
  91. .. warning::
  92. If you are working with a multi-GPU model, this function will only initialize
  93. the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
  94. """
  95. def cb():
  96. idx = current_device()
  97. default_generator = torch.cuda.default_generators[idx]
  98. default_generator.seed()
  99. _lazy_call(cb)
  100. def seed_all() -> None:
  101. r"""Sets the seed for generating random numbers to a random number on all GPUs.
  102. It's safe to call this function if CUDA is not available; in that
  103. case, it is silently ignored.
  104. """
  105. def cb():
  106. random_seed = 0
  107. seeded = False
  108. for i in range(device_count()):
  109. default_generator = torch.cuda.default_generators[i]
  110. if not seeded:
  111. default_generator.seed()
  112. random_seed = default_generator.initial_seed()
  113. seeded = True
  114. else:
  115. default_generator.manual_seed(random_seed)
  116. _lazy_call(cb)
  117. def initial_seed() -> int:
  118. r"""Returns the current random seed of the current GPU.
  119. .. warning::
  120. This function eagerly initializes CUDA.
  121. """
  122. _lazy_init()
  123. idx = current_device()
  124. default_generator = torch.cuda.default_generators[idx]
  125. return default_generator.initial_seed()