TensorOperators.h 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
#pragma once

#include <ATen/core/Tensor.h>
#include <c10/core/Scalar.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/remainder.h>
#include <ATen/ops/rsub.h>
#endif

#include <stdexcept>
#include <string>
  11. namespace at {
  12. #define AT_FORALL_BINARY_OPS(_) \
  13. _(+, x.add(y), y.add(x)) \
  14. _(*, x.mul(y), y.mul(x)) \
  15. _(-, \
  16. x.sub(y), \
  17. ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \
  18. _(/, \
  19. x.div(y), \
  20. ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \
  21. _(%, \
  22. x.remainder(y), \
  23. ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \
  24. _(&, x.bitwise_and(y), y.bitwise_and(x)) \
  25. _(|, x.bitwise_or(y), y.bitwise_or(x)) \
  26. _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \
  27. _(<, x.lt(y), y.gt(x)) \
  28. _(<=, x.le(y), y.ge(x)) \
  29. _(>, x.gt(y), y.lt(x)) \
  30. _(>=, x.ge(y), y.le(x)) \
  31. _(==, x.eq(y), y.eq(x)) \
  32. _(!=, x.ne(y), y.ne(x))
  33. #define DEFINE_OPERATOR(op, body, reverse_scalar_body) \
  34. static inline Tensor operator op(const Tensor& x, const Tensor& y) { \
  35. return body; \
  36. } \
  37. static inline Tensor operator op(const Tensor& x, const Scalar& y) { \
  38. return body; \
  39. } \
  40. static inline Tensor operator op(const Scalar& x, const Tensor& y) { \
  41. return reverse_scalar_body; \
  42. }
  43. AT_FORALL_BINARY_OPS(DEFINE_OPERATOR)
  44. #undef DEFINE_OPERATOR
  45. #undef AT_FORALL_BINARY_OPS
  46. } // namespace at