cuda_properties.py

import functools

import torch

# API to query CUDA properties that will work in a triton compile process
# that cannot use the GPU APIs (due to process fork() and initialization
# time issues). Properties are recorded in the main process before
# we fork the workers.


@functools.lru_cache(None)
def _properties():
    # Snapshot the properties of every visible device once, in the main
    # process; an empty dict means CUDA is unavailable or failed to initialize.
    if not torch.cuda.is_available():
        return {}
    try:
        return {
            i: torch.cuda.get_device_properties(i)
            for i in range(torch.cuda.device_count())
        }
    except RuntimeError:
        return {}


_compile_worker_current_device = None


def set_compiler_worker_current_device(device):
    # Called in a compile worker to record the device index chosen by the
    # parent process, so workers never need to touch the CUDA runtime.
    global _compile_worker_current_device
    _compile_worker_current_device = device


def current_device():
    if _compile_worker_current_device is not None:
        return _compile_worker_current_device
    return torch.cuda.current_device()


def _device(device):
    # Normalize a device argument (None, int, or torch.device) to an index.
    if device is not None:
        if isinstance(device, torch.device):
            assert device.type == "cuda"
            device = device.index
        return device
    return current_device()


def get_device_properties(device=None):
    return _properties()[_device(device)]


def get_device_capability(device=None):
    p = get_device_properties(device)
    return p.major, p.minor