__init__.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435
  1. import torch
  2. from functools import lru_cache as _lru_cache
  3. __all__ = ["is_built", "is_available", "is_macos13_or_newer"]
  4. def is_built() -> bool:
  5. r"""Returns whether PyTorch is built with MPS support. Note that this
  6. doesn't necessarily mean MPS is available; just that if this PyTorch
  7. binary were run a machine with working MPS drivers and devices, we
  8. would be able to use it."""
  9. return torch._C.has_mps
  10. @_lru_cache()
  11. def is_available() -> bool:
  12. r"""Returns a bool indicating if MPS is currently available."""
  13. return torch._C._mps_is_available()
  14. @_lru_cache()
  15. def is_macos13_or_newer() -> bool:
  16. r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer."""
  17. return torch._C._mps_is_on_macos_13_or_newer()
# Register Python implementations on the MPS dispatch key for three aten ops:
# var_mean.correction and native_group_norm (from torch._refs) and
# native_group_norm_backward (from torch._decomp.decompositions).
# Only done when this binary was built with MPS support.
if is_built():
    from ...library import Library as _Library
    from ..._refs import var_mean as _var_mean, native_group_norm as _native_group_norm
    from ..._decomp.decompositions import native_group_norm_backward as _native_group_norm_backward
    # "IMPL" library on the existing "aten" namespace: registers implementations
    # for already-defined aten operators rather than defining new ones.
    # NOTE(review): _lib is bound at module level, presumably so the registrations
    # stay alive for the life of the process — confirm against torch.library docs.
    _lib = _Library("aten", "IMPL")
    _lib.impl("var_mean.correction", _var_mean, "MPS")
    _lib.impl("native_group_norm", _native_group_norm, "MPS")
    _lib.impl("native_group_norm_backward", _native_group_norm_backward, "MPS")