TensorFactories.h

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/DispatchStub.h>

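// When AT_PER_OPERATOR_HEADERS is defined, include only the per-operator
// header actually used below (scalar_tensor) rather than the monolithic
// <ATen/Functions.h>; this keeps include dependencies and rebuilds smaller.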
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/scalar_tensor.h>
#endif

namespace at { namespace native {

// Different combinations of row, col, and offset can lead to two cases:
//
// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
// Example A: offset > 0
// 1 1 0 0 0
// 1 1 1 0 0
// 1 1 1 1 0
// Example B: offset <= 0
// 0 0 0
// 1 0 0
// 1 1 0
// In this case, we calculate the number of elements in the first row and
// last row of the tril respectively, and then compute the tril size.
//
// Case 2 - Trapezoid + Rectangle: row + offset > col
// Example:
// 1 1 0
// 1 1 1
// 1 1 1
// In this case, we first calculate the size of the top trapezoid, and then
// calculate the size of the bottom rectangle.
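//
// A quick numeric check of both cases (the values follow directly from the
// code below): get_tril_size(3, 5, 1) is Case 1 with rows of 2, 3, and 4
// ones, i.e. (2 + 4) * 3 / 2 = 9; get_tril_size(3, 3, 1) is Case 2 with a
// top trapezoid of 2 + 3 = 5 ones plus a 1 x 3 bottom rectangle, i.e. 8.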
inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
  // If either dimension is 0 then there is no tril
  if (row == 0 || col == 0) {
    return 0;
  }
  // number of elements in the first row of the tril
  auto m_first_row = offset > 0 ?
    std::min<int64_t>(col, 1 + offset) : // upper bounded by col
    row + offset > 0; // either 0 or 1
  // number of elements in the last row of the tril, bounded by [0, col]
  auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
  // number of rows, bounded by [0, row]
  auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
  auto n_row_trapezoid = (m_last_row - m_first_row + 1);
  // calculate # of elements in the top trapezoid
  auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1;
  // calculate # of elements in the bottom rectangle if there is any
  auto diff_row = n_row_all - n_row_trapezoid;
  if (diff_row > 0) {
    tril_size += diff_row * col;
  }
  return tril_size;
}
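// (This helper is what the tril_indices / triu_indices factories use, for
// example, to size their output tensors before filling in the indices.)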

inline void check_args(
    int64_t row, int64_t col, c10::optional<Layout> layout_opt) {
  TORCH_CHECK(row >= 0, "row must be non-negative, got ", row);
  TORCH_CHECK(col >= 0, "col must be non-negative, got ", col);
  if (layout_opt.has_value()) {
    TORCH_CHECK(
        *layout_opt == at::kStrided,
        "only support layout=torch.strided, got ",
        *layout_opt);
  }
}
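// For example, check_args(3, 4, c10::nullopt) passes, while a negative
// dimension or a non-strided layout such as at::kSparse throws a c10::Error.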

using at::check_size_nonnegative;

// assumes maximum value in created tensor is n-1 (e.g., torch.randperm(n))
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
  // match defined() to behavior of checks below
  TORCH_CHECK(at::scalar_tensor(n > 0 ? n - 1 : n, tensor.options()).defined(),
      "n is too large for result tensor type: '", tensor.toString(), "'");
  // Ensure sufficient precision for floating point representation.
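  // Why these bounds: an IEEE-754 significand of b bits represents every
  // integer up to 2^b exactly (Half: b = 11, Float: b = 24, Double: b = 53).
  // The largest value needed is n - 1, so n may be at most 2^b + 1.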
  switch (tensor.scalar_type()) {
    case at::ScalarType::Half:
      TORCH_CHECK(n <= (int64_t(1) << 11) + 1, "n cannot be greater than 2049 for Half type.");
      break;
    case at::ScalarType::Float:
      TORCH_CHECK(n <= (int64_t(1) << 24) + 1, "n cannot be greater than 2^24+1 for Float type.");
      break;
    case at::ScalarType::Double:  // Unlikely to happen, but doesn't hurt to check
      TORCH_CHECK(n <= (int64_t(1) << 53) + 1, "n cannot be greater than 2^53+1 for Double type.");
      break;
    default:
      break;
  }
}
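// (randperm is the canonical caller: it runs this check on n before writing
// the values 0 .. n-1 into the newly created result tensor.)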

// The ZeroTensor allocator ignores whatever allocation is requested and
// always gives you nullptr
struct ZeroTensorAllocator final : public at::Allocator {
  ZeroTensorAllocator(at::Device device) : device_(device) {}
  ~ZeroTensorAllocator() override = default;
  static void deleter(void* const pointer) {
    TORCH_INTERNAL_ASSERT(!pointer);
  }
  DataPtr allocate(const size_t /*nbytes*/) const override {
    return {nullptr, nullptr, &deleter, device_};
  }
  DeleterFnPtr raw_deleter() const override {
    return deleter;
  }
  at::Device device_;
};
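// A sketch of the intent (not part of this header): a zero-tensor factory
// can build storage through this allocator, so the "data" occupies no memory
// and the deleter only ever sees the nullptr it handed out.
//
//   ZeroTensorAllocator alloc(at::Device(at::kCPU));
//   at::DataPtr p = alloc.allocate(1024);  // p.get() == nullptr, nothing allocated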

using binary_fn = void (*)(TensorIterator&);

DECLARE_DISPATCH(binary_fn, complex_stub);
DECLARE_DISPATCH(binary_fn, polar_stub);
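// complex_stub and polar_stub back at::complex (a real/imag pair combined
// into a complex tensor) and at::polar (an abs/angle pair), with per-backend
// kernels registered via REGISTER_DISPATCH.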

} // namespace native
} // namespace at