#pragma once #include #include namespace at { namespace cuda { /// Allocator for Thrust to re-route its internal device allocations /// to the THC allocator class ThrustAllocator { public: typedef char value_type; char* allocate(std::ptrdiff_t size) { return static_cast(c10::cuda::CUDACachingAllocator::raw_alloc(size)); } void deallocate(char* p, size_t size) { c10::cuda::CUDACachingAllocator::raw_delete(p); } }; } }