InferSize.h 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #pragma once
  2. #include <ATen/DimVector.h>
  3. #include <c10/core/ScalarType.h>
  4. #include <c10/core/SymIntArrayRef.h>
  5. #include <c10/util/DimVector.h>
  6. #include <c10/util/Optional.h>
  7. #include <sstream>
  8. #include <vector>
  9. namespace at {
  10. // Infers the size of a dim with size -1, if it exists. Also checks that new
  11. // shape is compatible with the number of elements.
  12. //
  13. // templated to handle std::vector<int64_t> and DimVector use cases, see
  14. // below
  15. //
  16. template <typename InputArrayRef, typename NumelType, typename ResultVec>
  17. inline void infer_size_impl(
  18. InputArrayRef shape,
  19. NumelType numel,
  20. ResultVec& res) {
  21. NumelType newsize = 1;
  22. // N.B. this is an index, not a sym dim!
  23. auto infer_dim = c10::optional<int64_t>();
  24. for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
  25. if (shape[dim] == -1) {
  26. if (infer_dim) {
  27. throw std::runtime_error("only one dimension can be inferred");
  28. }
  29. infer_dim = dim;
  30. } else if (shape[dim] >= 0) {
  31. newsize *= shape[dim];
  32. } else {
  33. AT_ERROR("invalid shape dimension ", shape[dim]);
  34. }
  35. }
  36. if (numel == newsize || (infer_dim && newsize > 0 && numel % newsize == 0)) {
  37. if (infer_dim) {
  38. // We have a degree of freedom here to select the dimension size; follow
  39. // NumPy semantics and just bail. However, a nice error message is needed
  40. // because users often use `view` as a way to flatten & unflatten
  41. // dimensions and will otherwise be confused why
  42. // empty_tensor.view( 0, 0)
  43. // works yet
  44. // empty_tensor.view(-1, 0)
  45. // doesn't.
  46. TORCH_CHECK(
  47. newsize != 0,
  48. "cannot reshape tensor of 0 elements into shape ",
  49. shape,
  50. " because the unspecified dimension size -1 can be any "
  51. "value and is ambiguous");
  52. res[*infer_dim] = numel / newsize;
  53. }
  54. return;
  55. }
  56. std::ostringstream ss;
  57. ss << "shape '" << shape << "' is invalid for input of size " << numel;
  58. throw std::runtime_error(ss.str());
  59. }
  60. inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
  61. auto res = shape.vec();
  62. infer_size_impl(shape, numel, res);
  63. return res;
  64. }
  65. inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
  66. auto res = at::DimVector(shape);
  67. infer_size_impl(shape, numel, res);
  68. return res;
  69. }
  70. inline at::SymDimVector infer_size_dv(
  71. c10::SymIntArrayRef shape,
  72. c10::SymInt numel) {
  73. auto res = at::SymDimVector(shape);
  74. infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
  75. shape, std::move(numel), res);
  76. return res;
  77. }
  78. } // namespace at