123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- #pragma once
- #include <ATen/core/Tensor.h>
- #include <c10/core/Scalar.h>
- #ifndef AT_PER_OPERATOR_HEADERS
- #include <ATen/Functions.h>
- #else
- #include <ATen/ops/empty_like.h>
- #endif
- #include <stdexcept>
- #include <string>
- namespace at {
- #define AT_FORALL_BINARY_OPS(_) \
- _(+, x.add(y), y.add(x)) \
- _(*, x.mul(y), y.mul(x)) \
- _(-, \
- x.sub(y), \
- ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \
- _(/, \
- x.div(y), \
- ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \
- _(%, \
- x.remainder(y), \
- ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \
- _(&, x.bitwise_and(y), y.bitwise_and(x)) \
- _(|, x.bitwise_or(y), y.bitwise_or(x)) \
- _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \
- _(<, x.lt(y), y.gt(x)) \
- _(<=, x.le(y), y.ge(x)) \
- _(>, x.gt(y), y.lt(x)) \
- _(>=, x.ge(y), y.le(x)) \
- _(==, x.eq(y), y.eq(x)) \
- _(!=, x.ne(y), y.ne(x))
- #define DEFINE_OPERATOR(op, body, reverse_scalar_body) \
- static inline Tensor operator op(const Tensor& x, const Tensor& y) { \
- return body; \
- } \
- static inline Tensor operator op(const Tensor& x, const Scalar& y) { \
- return body; \
- } \
- static inline Tensor operator op(const Scalar& x, const Tensor& y) { \
- return reverse_scalar_body; \
- }
- AT_FORALL_BINARY_OPS(DEFINE_OPERATOR)
- #undef DEFINE_OPERATOR
- #undef AT_FORALL_BINARY_OPS
- } // namespace at
|