import functools import torch # API to query cuda properties that will work in a triton compile process # that cannot use the GPU APIs (due to processing fork() and initialization # time issues). Properties are recorded in the main process before # we fork the workers. @functools.lru_cache(None) def _properties(): if not torch.cuda.is_available(): return {} try: return { i: torch.cuda.get_device_properties(i) for i in range(torch.cuda.device_count()) } except RuntimeError: return {} _compile_worker_current_device = None def set_compiler_worker_current_device(device): global _compile_worker_current_device _compile_worker_current_device = device def current_device(): if _compile_worker_current_device is not None: return _compile_worker_current_device return torch.cuda.current_device() def _device(device): if device is not None: if isinstance(device, torch.device): assert device.type == "cuda" device = device.index return device return current_device() def get_device_properties(device=None): return _properties()[_device(device)] def get_device_capability(device=None): p = get_device_properties(device) return p.major, p.minor