// BinaryInternal.h
// DON'T include this except from Binary*.cu files. It should not leak into
// headers.
#pragma once
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/TypeSafeSignMath.h>
#include <type_traits>
  16. namespace at {
  17. namespace native {
  18. namespace binary_internal {
  19. template <typename scalar_t>
  20. struct DivFunctor {
  21. __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
  22. return a / b;
  23. }
  24. };
  25. template <typename T>
  26. struct MulFunctor {
  27. __device__ T operator()(T a, T b) const {
  28. return a * b;
  29. }
  30. };
  31. // Workaround for the error: '*' in boolean context, suggest '&&' instead
  32. // [-Werror=int-in-bool-context]
  33. template <>
  34. struct MulFunctor<bool> {
  35. __device__ bool operator()(bool a, bool b) const {
  36. return a && b;
  37. }
  38. };
  39. void div_true_kernel_cuda(TensorIteratorBase& iter);
  40. void div_trunc_kernel_cuda(TensorIteratorBase& iter);
  41. } // namespace binary_internal
  42. } // namespace native
  43. } // namespace at