123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- import functools
- import torch
- # API to query cuda properties that will work in a triton compile process
- # that cannot use the GPU APIs (due to process fork() and CUDA
- # initialization-time issues). Properties are recorded in the main process
- # before we fork the workers.
- @functools.lru_cache(None)
- def _properties():
- if not torch.cuda.is_available():
- return {}
- try:
- return {
- i: torch.cuda.get_device_properties(i)
- for i in range(torch.cuda.device_count())
- }
- except RuntimeError:
- return {}
- _compile_worker_current_device = None
def set_compiler_worker_current_device(device):
    """Record *device* as this process's current CUDA device index.

    After this call, current_device() returns *device* instead of querying
    ``torch.cuda.current_device()`` — useful where the CUDA runtime cannot
    be touched.
    """
    global _compile_worker_current_device
    _compile_worker_current_device = device
def current_device():
    """Return the active CUDA device index.

    Prefers the index recorded via set_compiler_worker_current_device();
    only when no override is set does it ask the CUDA runtime.
    """
    override = _compile_worker_current_device
    if override is None:
        return torch.cuda.current_device()
    return override
- def _device(device):
- if device is not None:
- if isinstance(device, torch.device):
- assert device.type == "cuda"
- device = device.index
- return device
- return current_device()
def get_device_properties(device=None):
    """Return the cached properties for *device* (default: current device).

    Raises ``KeyError`` when no properties were recorded for that index
    (e.g. CUDA was unavailable when the cache was populated).
    """
    idx = _device(device)
    return _properties()[idx]
def get_device_capability(device=None):
    """Return the compute capability of *device* as a ``(major, minor)`` tuple."""
    props = get_device_properties(device)
    return (props.major, props.minor)
|