CUDAHooks.h 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. #pragma once
  2. #include <ATen/detail/CUDAHooksInterface.h>
  3. #include <ATen/Generator.h>
  4. #include <c10/util/Optional.h>
  5. // TODO: No need to have this whole header, we can just put it all in
  6. // the cpp file
  7. namespace at { namespace cuda { namespace detail {
  8. // Set the callback to initialize Magma, which is set by
  9. // torch_cuda_cu. This indirection is required so magma_init is called
  10. // in the same library where Magma will be used.
  11. TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
  12. // The real implementation of CUDAHooksInterface
  13. struct CUDAHooks : public at::CUDAHooksInterface {
  14. CUDAHooks(at::CUDAHooksArgs) {}
  15. void initCUDA() const override;
  16. Device getDeviceFromPtr(void* data) const override;
  17. bool isPinnedPtr(void* data) const override;
  18. const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
  19. bool hasCUDA() const override;
  20. bool hasMAGMA() const override;
  21. bool hasCuDNN() const override;
  22. bool hasCuSOLVER() const override;
  23. bool hasROCM() const override;
  24. const at::cuda::NVRTC& nvrtc() const override;
  25. int64_t current_device() const override;
  26. bool hasPrimaryContext(int64_t device_index) const override;
  27. Allocator* getCUDADeviceAllocator() const override;
  28. Allocator* getPinnedMemoryAllocator() const override;
  29. bool compiledWithCuDNN() const override;
  30. bool compiledWithMIOpen() const override;
  31. bool supportsDilatedConvolutionWithCuDNN() const override;
  32. bool supportsDepthwiseConvolutionWithCuDNN() const override;
  33. bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
  34. bool hasCUDART() const override;
  35. long versionCUDART() const override;
  36. long versionCuDNN() const override;
  37. std::string showConfig() const override;
  38. double batchnormMinEpsilonCuDNN() const override;
  39. int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override;
  40. void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override;
  41. int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
  42. void cuFFTClearPlanCache(int64_t device_index) const override;
  43. int getNumGPUs() const override;
  44. void deviceSynchronize(int64_t device_index) const override;
  45. };
  46. }}} // at::cuda::detail