CompositeViewCopyKernels.cpp 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
  2. // ${generated_comment}
  3. #include <ATen/InferSize.h>
  4. #include <ATen/Tensor.h>
  5. #include <ATen/native/Resize.h>
  6. #ifndef AT_PER_OPERATOR_HEADERS
  7. #include <ATen/Operators.h>
  8. #else
  9. #include <ATen/ops/clone.h>
  10. $ops_headers
  11. #endif
  12. namespace at {
  13. namespace native {
  14. // This file contains a number of kernels for aten functions that are fully code-generated.
  15. // TODO: rename this file to something more generic.
  16. at::Tensor clone_arg(const at::Tensor& t) {
  17. return t.clone();
  18. }
  19. std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
  20. std::vector<at::Tensor> out(t_list.size());
  21. for (const auto& i : c10::irange(t_list.size())) {
  22. out[i] = t_list[i].clone();
  23. }
  24. return out;
  25. }
  26. // duped with gen_resize_out_helper from structured kernels
  27. void copy_arg(const at::Tensor& dst, const at::Tensor& src) {
  28. TORCH_CHECK(src.dtype() == dst.dtype(),
  29. "Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead");
  30. TORCH_CHECK(src.device() == dst.device(),
  31. "Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead");
  32. dst.copy_(src);
  33. }
  34. void copy_arg(const at::TensorList& dst, const at::TensorList& src) {
  35. TORCH_INTERNAL_ASSERT(dst.size() == src.size());
  36. for (const auto& i : c10::irange(dst.size())) {
  37. copy_arg(dst[i], src[i]);
  38. }
  39. }
  40. // TODO: this doesn't handle restriding empty tensors correctly; see
  41. // gen_resize_out_helper for the correct algorithm
  42. void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) {
  43. at::native::resize_output(dst, src.sizes());
  44. }
  45. void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
  46. TORCH_INTERNAL_ASSERT(dst.size() == src.size());
  47. for (const auto& i : c10::irange(dst.size())) {
  48. at::native::resize_output(dst[i], src[i].sizes());
  49. }
  50. }
  51. ${CompositeViewCopyKernel_Definitions}
  52. ${GeneratedCompositeFunctional_Definitions}
  53. ${GeneratedCompositeOut_Definitions}
  54. } // namespace native
  55. } // namespace at