# random.py

import contextlib
import warnings
from typing import Generator

import torch
from torch._C import default_generator


def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note:: This function only works for CPU. For CUDA, please use
        torch.manual_seed(seed), which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
    """
    default_generator.set_state(new_state)


def get_rng_state() -> torch.Tensor:
    r"""Returns the random number generator state as a `torch.ByteTensor`."""
    return default_generator.get_state()
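

# Illustrative usage sketch (an addition, not part of the original module):
# saving and restoring the CPU RNG state replays the same random sequence.
# `_example_rng_state` is a hypothetical helper named here for demonstration.
def _example_rng_state() -> None:
    state = get_rng_state()
    a = torch.rand(3)
    set_rng_state(state)  # rewind the default generator
    b = torch.rand(3)
    assert torch.equal(a, b)  # the rewound generator repeats the sequence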


def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers. Returns a
    `torch.Generator` object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a
            RuntimeError is raised. Negative inputs are remapped to positive
            values with the formula `0xffff_ffff_ffff_ffff + seed`.
    """
    seed = int(seed)
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps
    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    return default_generator.manual_seed(seed)
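

# Illustrative usage sketch (an addition, not part of the original module):
# seeding the default generator makes subsequent draws reproducible.
# `_example_manual_seed` is a hypothetical helper named for demonstration.
def _example_manual_seed() -> None:
    manual_seed(0)
    a = torch.rand(3)
    manual_seed(0)  # re-seeding resets the sequence
    b = torch.rand(3)
    assert torch.equal(a, b)  # identical seeds yield identical draws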


def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number. Returns a 64 bit number used to seed the RNG.
    """
    seed = default_generator.seed()
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps
    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    return seed
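

# Illustrative usage sketch (an addition, not part of the original module):
# seed() picks a fresh non-deterministic seed and returns it, so a run can be
# replayed later by feeding that value back into manual_seed().
# `_example_seed` is a hypothetical helper named for demonstration.
def _example_seed() -> None:
    s = seed()        # non-deterministic seed, applied to CPU, CUDA, and MPS
    a = torch.rand(3)
    manual_seed(s)    # replay: re-seed with the value seed() returned
    b = torch.rand(3)
    assert torch.equal(a, b)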


def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python `long`.
    """
    return default_generator.initial_seed()
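

# Illustrative usage sketch (an addition, not part of the original module):
# initial_seed() reports the seed the default generator was last seeded with.
# `_example_initial_seed` is a hypothetical helper named for demonstration.
def _example_initial_seed() -> None:
    manual_seed(42)
    assert initial_seed() == 42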


_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices") -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of CUDA IDs): CUDA devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng`
            operates on all devices, but will emit a warning if your machine has
            a lot of devices, since this function will run very slowly in that
            case. If you explicitly specify devices, this warning will be
            suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
    """
    import torch.cuda
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller
    if not enabled:
        yield
        return

    if devices is None:
        num_devices = torch.cuda.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                ("CUDA reports that you have {num_devices} available devices, and you "
                 "have used {caller} without explicitly specifying which devices are being used. "
                 "For safety, we initialize *every* CUDA device by default, which "
                 "can be quite slow if you have a lot of GPUs. If you know that you are only "
                 "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                 "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                 "you are actually using. For example, if you are using CPU only, "
                 "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                 "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0]. To initialize "
                 "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                 "to `range(torch.cuda.device_count())`.").format(
                    num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against the user passing us a generator; we need to traverse
        # this multiple times, but a generator will be exhausted upon the first
        # traversal.
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    gpu_rng_states = []
    for device in devices:
        gpu_rng_states.append(torch.cuda.get_rng_state(device))

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, gpu_rng_state in zip(devices, gpu_rng_states):
            torch.cuda.set_rng_state(gpu_rng_state, device)
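

# Illustrative usage sketch (an addition, not part of the original module):
# fork_rng() lets a block consume random numbers without perturbing the
# caller's sequence. `_example_fork_rng` is a hypothetical helper named for
# demonstration; devices=[] keeps the sketch CPU-only.
def _example_fork_rng() -> None:
    before = get_rng_state()
    with fork_rng(devices=[]):  # fork the CPU state; skip CUDA devices
        torch.rand(10)          # draws inside the block are isolated
    assert torch.equal(before, get_rng_state())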