device_context.py 634 B

12345678910111213141516171819202122232425
  1. import threading
  2. from typing import Any, Dict
  3. import torch._C._lazy
  4. class DeviceContext:
  5. _CONTEXTS: Dict[str, Any] = dict()
  6. _CONTEXTS_LOCK = threading.Lock()
  7. def __init__(self, device):
  8. self.device = device
  9. def get_device_context(device=None):
  10. if device is None:
  11. device = torch._C._lazy._get_default_device_type()
  12. else:
  13. device = str(device)
  14. with DeviceContext._CONTEXTS_LOCK:
  15. devctx = DeviceContext._CONTEXTS.get(device, None)
  16. if devctx is None:
  17. devctx = DeviceContext(device)
  18. DeviceContext._CONTEXTS[device] = devctx
  19. return devctx