12345678910111213141516171819202122232425 |
- import threading
- from typing import Any, Dict
- import torch._C._lazy
class DeviceContext:
    """Per-device state holder with a class-level, lock-guarded registry.

    Instances only record the device key they were created for. The
    registry and its lock live on the class so every caller shares them.
    """

    # Maps a device key (string) to its DeviceContext instance.
    _CONTEXTS: Dict[str, Any] = {}
    # Guards reads and writes of _CONTEXTS across threads.
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device):
        """Remember *device* as this context's device key."""
        self.device = device
def get_device_context(device=None):
    """Return the shared DeviceContext for *device*, creating it on first use.

    Args:
        device: Device identifier. When ``None``, the key is the value
            returned by ``torch._C._lazy._get_default_device_type()``;
            otherwise the key is ``str(device)``.

    Returns:
        The DeviceContext stored in ``DeviceContext._CONTEXTS`` under that
        key. Repeated calls with the same key return the same object.
    """
    key = (
        torch._C._lazy._get_default_device_type()
        if device is None
        else str(device)
    )
    # Do the lookup and insert under one lock acquisition, so concurrent
    # callers cannot each create a context for the same key.
    with DeviceContext._CONTEXTS_LOCK:
        registry = DeviceContext._CONTEXTS
        ctx = registry.get(key)
        if ctx is None:
            ctx = registry[key] = DeviceContext(key)
        return ctx
|