cat_meta.h 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. #pragma once
  2. // @generated by torchgen/gen.py from NativeMetaFunction.h
  3. #include <c10/core/Scalar.h>
  4. #include <c10/core/Storage.h>
  5. #include <c10/core/TensorOptions.h>
  6. #include <c10/util/Deprecated.h>
  7. #include <c10/util/Optional.h>
  8. #include <c10/core/QScheme.h>
  9. #include <ATen/core/Reduction.h>
  10. #include <ATen/TensorIterator.h>
  11. #include <ATen/TensorMeta.h>
  12. #include <tuple>
  13. #include <vector>
  14. namespace at {
  15. namespace meta {
  16. struct TORCH_API structured_cat : public at::impl::MetaBase {
  17. template <bool DIM = false, bool VALID = false, bool ALL_CONTIGUOUS = false, bool ALL_SAME_DTYPE = false, bool ALL_SAME_SIZES_AND_STRIDE = false, bool MEMORY_FORMAT = false>
  18. struct TORCH_API precompute_out {
  19. precompute_out<true, VALID, ALL_CONTIGUOUS, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> set_dim(int64_t value) {
  20. static_assert(DIM == false, "dim already set");
  21. precompute_out<true, VALID, ALL_CONTIGUOUS, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> ret;
  22. ret.dim = value;
  23. ret.valid = this->valid;
  24. ret.all_contiguous = this->all_contiguous;
  25. ret.all_same_dtype = this->all_same_dtype;
  26. ret.all_same_sizes_and_stride = this->all_same_sizes_and_stride;
  27. ret.memory_format = this->memory_format;
  28. return ret;
  29. }
  30. precompute_out<DIM, true, ALL_CONTIGUOUS, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> set_valid(int64_t value) {
  31. static_assert(VALID == false, "valid already set");
  32. precompute_out<DIM, true, ALL_CONTIGUOUS, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> ret;
  33. ret.dim = this->dim;
  34. ret.valid = value;
  35. ret.all_contiguous = this->all_contiguous;
  36. ret.all_same_dtype = this->all_same_dtype;
  37. ret.all_same_sizes_and_stride = this->all_same_sizes_and_stride;
  38. ret.memory_format = this->memory_format;
  39. return ret;
  40. }
  41. precompute_out<DIM, VALID, true, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> set_all_contiguous(bool value) {
  42. static_assert(ALL_CONTIGUOUS == false, "all_contiguous already set");
  43. precompute_out<DIM, VALID, true, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> ret;
  44. ret.dim = this->dim;
  45. ret.valid = this->valid;
  46. ret.all_contiguous = value;
  47. ret.all_same_dtype = this->all_same_dtype;
  48. ret.all_same_sizes_and_stride = this->all_same_sizes_and_stride;
  49. ret.memory_format = this->memory_format;
  50. return ret;
  51. }
  52. precompute_out<DIM, VALID, ALL_CONTIGUOUS, true, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> set_all_same_dtype(bool value) {
  53. static_assert(ALL_SAME_DTYPE == false, "all_same_dtype already set");
  54. precompute_out<DIM, VALID, ALL_CONTIGUOUS, true, ALL_SAME_SIZES_AND_STRIDE, MEMORY_FORMAT> ret;
  55. ret.dim = this->dim;
  56. ret.valid = this->valid;
  57. ret.all_contiguous = this->all_contiguous;
  58. ret.all_same_dtype = value;
  59. ret.all_same_sizes_and_stride = this->all_same_sizes_and_stride;
  60. ret.memory_format = this->memory_format;
  61. return ret;
  62. }
  63. precompute_out<DIM, VALID, ALL_CONTIGUOUS, ALL_SAME_DTYPE, true, MEMORY_FORMAT> set_all_same_sizes_and_stride(bool value) {
  64. static_assert(ALL_SAME_SIZES_AND_STRIDE == false, "all_same_sizes_and_stride already set");
  65. precompute_out<DIM, VALID, ALL_CONTIGUOUS, ALL_SAME_DTYPE, true, MEMORY_FORMAT> ret;
  66. ret.dim = this->dim;
  67. ret.valid = this->valid;
  68. ret.all_contiguous = this->all_contiguous;
  69. ret.all_same_dtype = this->all_same_dtype;
  70. ret.all_same_sizes_and_stride = value;
  71. ret.memory_format = this->memory_format;
  72. return ret;
  73. }
  74. precompute_out<DIM, VALID, ALL_CONTIGUOUS, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, true> set_memory_format(at::MemoryFormat value) {
  75. static_assert(MEMORY_FORMAT == false, "memory_format already set");
  76. precompute_out<DIM, VALID, ALL_CONTIGUOUS, ALL_SAME_DTYPE, ALL_SAME_SIZES_AND_STRIDE, true> ret;
  77. ret.dim = this->dim;
  78. ret.valid = this->valid;
  79. ret.all_contiguous = this->all_contiguous;
  80. ret.all_same_dtype = this->all_same_dtype;
  81. ret.all_same_sizes_and_stride = this->all_same_sizes_and_stride;
  82. ret.memory_format = value;
  83. return ret;
  84. }
  85. int64_t dim;
  86. int64_t valid;
  87. bool all_contiguous;
  88. bool all_same_dtype;
  89. bool all_same_sizes_and_stride;
  90. at::MemoryFormat memory_format;
  91. };
  92. using meta_return_ty = precompute_out <true, true, true, true, true, true>;
  93. meta_return_ty meta(const at::ITensorListRef & tensors, int64_t dim);
  94. };
  95. } // namespace native
  96. } // namespace at