ThrustAllocator.h 496 B

12345678910111213141516171819202122232425
  1. #pragma once
  2. #include <cstddef>
  3. #include <c10/cuda/CUDACachingAllocator.h>
  4. namespace at {
  5. namespace cuda {
  6. /// Allocator for Thrust to re-route its internal device allocations
  7. /// to the THC allocator
  8. class ThrustAllocator {
  9. public:
  10. typedef char value_type;
  11. char* allocate(std::ptrdiff_t size) {
  12. return static_cast<char*>(c10::cuda::CUDACachingAllocator::raw_alloc(size));
  13. }
  14. void deallocate(char* p, size_t size) {
  15. c10::cuda::CUDACachingAllocator::raw_delete(p);
  16. }
  17. };
  18. }
  19. }