EmbeddingBackwardKernel.cuh 543 B

12345678910111213141516171819202122
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/cuda/Atomic.cuh>
  4. #include <ATen/cuda/CUDAContext.h>
  5. #include <ATen/TensorUtils.h>
  namespace at {
  namespace native {

  // Declaration only: the body of embedding_backward_cuda_kernel is defined
  // outside this header (presumably in a .cu translation unit -- confirm).
  //
  // Going by the name, it computes the backward pass (weight gradient) of an
  // embedding lookup. NOTE(review): verify the exact contract, including the
  // shapes and dtypes of every tensor, against the definition.
  //
  // Parameters. Meanings are assumptions unless stated as defaults:
  //   grad               - incoming gradient tensor.
  //   orig_indices       - presumably the lookup indices' original positions.
  //   sorted_indices     - presumably the lookup indices in sorted order.
  //   count              - presumably per-index occurrence counts; TODO confirm.
  //   num_weights        - presumably the number of rows in the weight table.
  //   padding_idx        - defaults to -1; presumably -1 means "no padding row".
  //   mode_mean          - defaults to false; presumably enables mean reduction
  //                        for bag-style callers.
  //   offset2bag         - defaults to Tensor(); presumably bag-mode only.
  //   bag_size           - defaults to Tensor(); presumably bag-mode only.
  //   per_sample_weights - defaults to Tensor(); presumably optional weighting.
  // Returns: a Tensor (shape/dtype not visible from this declaration).
  //
  // NOTE(review): padding_idx is `int` while num_weights is `int64_t`. If one
  // is ever widened, change this declaration and the definition together, or
  // they will no longer match.
  Tensor embedding_backward_cuda_kernel(
      const Tensor& grad,
      const Tensor& orig_indices,
      const Tensor& sorted_indices,
      const Tensor& count,
      int64_t num_weights,
      int padding_idx = -1,
      bool mode_mean = false,
      const Tensor& offset2bag = Tensor(),
      const Tensor& bag_size = Tensor(),
      const Tensor& per_sample_weights = Tensor());

  }  // namespace native
  }  // namespace at