MPSHooksInterface.h 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. // Copyright © 2022 Apple Inc.
  2. #pragma once
  3. #include <c10/core/Allocator.h>
  4. #include <ATen/core/Generator.h>
  5. #include <c10/util/Exception.h>
  6. #include <c10/util/Registry.h>
  7. #include <cstddef>
  8. #include <functional>
  9. namespace at {
  10. class Context;
  11. }
  12. namespace at {
  13. struct TORCH_API MPSHooksInterface {
  14. virtual ~MPSHooksInterface() = default;
  15. // Initialize the MPS library state
  16. virtual void initMPS() const {
  17. AT_ERROR("Cannot initialize MPS without MPS backend.");
  18. }
  19. virtual bool hasMPS() const {
  20. return false;
  21. }
  22. virtual bool isOnMacOS13orNewer() const {
  23. AT_ERROR("MPS backend is not available.");
  24. }
  25. virtual const Generator& getDefaultMPSGenerator() const {
  26. AT_ERROR("Cannot get default MPS generator without MPS backend.");
  27. }
  28. virtual Allocator* getMPSDeviceAllocator() const {
  29. AT_ERROR("MPSDeviceAllocator requires MPS.");
  30. }
  31. virtual void deviceSynchronize() const {
  32. AT_ERROR("Cannot synchronize MPS device without MPS backend.");
  33. }
  34. virtual void emptyCache() const {
  35. AT_ERROR("Cannot execute emptyCache() without MPS backend.");
  36. }
  37. virtual size_t getCurrentAllocatedMemory() const {
  38. AT_ERROR("Cannot execute getCurrentAllocatedMemory() without MPS backend.");
  39. }
  40. virtual size_t getDriverAllocatedMemory() const {
  41. AT_ERROR("Cannot execute getDriverAllocatedMemory() without MPS backend.");
  42. }
  43. virtual void setMemoryFraction(double /*ratio*/) const {
  44. AT_ERROR("Cannot execute setMemoryFraction() without MPS backend.");
  45. }
  46. };
  47. struct TORCH_API MPSHooksArgs {};
  48. C10_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs);
  49. #define REGISTER_MPS_HOOKS(clsname) \
  50. C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname)
  51. namespace detail {
  52. TORCH_API const MPSHooksInterface& getMPSHooks();
  53. } // namespace detail
  54. } // namespace at