// ExpandBase.h (914 B)
  1. #include <ATen/core/TensorBase.h>
  2. // Broadcasting utilities for working with TensorBase
  3. namespace at {
  4. namespace internal {
  5. TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size);
  6. } // namespace internal
  7. inline c10::MaybeOwned<TensorBase> expand_size(
  8. const TensorBase& self,
  9. IntArrayRef size) {
  10. if (size.equals(self.sizes())) {
  11. return c10::MaybeOwned<TensorBase>::borrowed(self);
  12. }
  13. return c10::MaybeOwned<TensorBase>::owned(
  14. at::internal::expand_slow_path(self, size));
  15. }
  16. c10::MaybeOwned<TensorBase> expand_size(TensorBase&& self, IntArrayRef size) =
  17. delete;
  18. inline c10::MaybeOwned<TensorBase> expand_inplace(
  19. const TensorBase& tensor,
  20. const TensorBase& to_expand) {
  21. return expand_size(to_expand, tensor.sizes());
  22. }
  23. c10::MaybeOwned<TensorBase> expand_inplace(
  24. const TensorBase& tensor,
  25. TensorBase&& to_expand) = delete;
  26. } // namespace at