ScalarOps.h

#pragma once

#include <ATen/Tensor.h>
#include <c10/core/Scalar.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at {
namespace detail {
// When filling a number into a 1-element CPU tensor, we want to skip
// everything else and manipulate the data ptr directly.
// Ideally this fast path should be implemented in TensorIterator,
// but we also want to skip compute_types, which is not avoidable
// in TensorIterator for now.
Tensor& scalar_fill(Tensor& self, const Scalar& value);
TORCH_API Tensor scalar_tensor_static(
    const Scalar& s,
    c10::optional<ScalarType> dtype_opt,
    c10::optional<Device> device_opt);
} // namespace detail
} // namespace at
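
// Illustrative sketch (not part of the original header): one way the fast
// path above might be exercised; the specific values are hypothetical.
//
//   at::Tensor t = at::detail::scalar_tensor_static(
//       /*s=*/3.0, /*dtype_opt=*/at::kDouble, /*device_opt=*/at::kCPU);
//   at::detail::scalar_fill(t, 5.0); // overwrite the single element in place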

// This is in the c10 namespace because we use ADL to find the functions in it.
namespace c10 {

// FIXME: this should be (and was) Scalar::toTensor, but there is currently no
// way to implement this without going through Derived Types (which are not
// part of core).
inline at::Tensor scalar_to_tensor(
    const Scalar& s,
    const Device device = at::kCPU) {
  // This is the fast track we have for CPU scalar tensors.
  if (device == at::kCPU) {
    if (s.isFloatingPoint()) {
      return at::detail::scalar_tensor_static(s, at::kDouble, at::kCPU);
    } else if (s.isComplex()) {
      return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
    } else if (s.isBoolean()) {
      return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
    } else {
      AT_ASSERT(s.isIntegral(false));
      return at::detail::scalar_tensor_static(s, at::kLong, at::kCPU);
    }
  }
  if (s.isFloatingPoint()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kDouble));
  } else if (s.isBoolean()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
  } else if (s.isComplex()) {
    return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
  } else {
    AT_ASSERT(s.isIntegral(false));
    return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
  }
}

} // namespace c10
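
// Illustrative sketch (not part of the original header): scalar_to_tensor
// produces a 0-dim tensor in the widest dtype of the Scalar's category, e.g.
//
//   at::Tensor d = c10::scalar_to_tensor(c10::Scalar(1.5));  // kDouble, CPU fast path
//   at::Tensor b = c10::scalar_to_tensor(c10::Scalar(true)); // kBool, CPU fast path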

namespace at {
namespace native {

inline Tensor wrapped_scalar_tensor(
    const Scalar& scalar,
    const Device device = at::kCPU) {
  auto tensor = scalar_to_tensor(scalar, device);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}

} // namespace native
} // namespace at
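
// Illustrative sketch (not part of the original header): kernels that need a
// Tensor operand for a Scalar argument wrap it like this; the result is
// flagged as a "wrapped number" so type promotion treats it as a plain number
// rather than a full tensor.
//
//   at::Tensor two = at::native::wrapped_scalar_tensor(/*scalar=*/2);
//   // two.unsafeGetTensorImpl()->is_wrapped_number() == true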