TensorSubclassLikeUtils.h

#pragma once

#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/equal.h>
#endif
namespace at {

// Note [Tensor-subclass-like Tensors]
// Tensor-subclass-like is defined as:
// - a Tensor subclass (via __torch_dispatch__ in Python or by extending
//   TensorImpl in C++)
// - anything else that shares the same perils as Tensor subclasses.
//   For example, many Tensor subclasses do not have storage, and meta Tensors
//   do not have storage either, so meta Tensors belong here.
//
// We should ensure that PyTorch internals support Tensor-subclass-like
// objects. In particular, Tensor-subclass-like objects struggle with two
// classes of operations that are problematic for Tensor subclasses:
// 1. Because some Tensor subclasses do not have storage, .item() or
//    .data_ptr() calls must be avoided.
// 2. Certain in-place operations can eliminate the typing of the Tensor
//    subclass. For example:
//    >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
//    If input is a Tensor subclass, then the above ends up either erroring out
//    or returning a regular non-Tensor-subclass Tensor!
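//
// A composite-compliant rewrite constructs the result out-of-place so that
// every operation dispatches through the subclass and preserves its type.
// Illustrative sketch only (assumes a 1-D input; not part of the original
// note):
// >>> input.diag()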
constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
    {DispatchKey::FuncTorchGradWrapper,
     DispatchKey::FuncTorchBatched,
     DispatchKey::Functionalize});

constexpr auto kTensorSubclassLike =
    kFunctorchWrappedTensors |
    DispatchKeySet(
        {// WARNING: DO NOT put combined backend component + functionality keys
         // here, you will incorrectly always match on the functionality key
         // no matter the backend component
         DispatchKey::Batched,
         DispatchKey::Sparse,
         DispatchKey::SparseCsrCPU,
         DispatchKey::SparseCsrCUDA,
         DispatchKey::Python}) |
    DispatchKeySet(BackendComponent::MetaBit);
inline bool isTensorSubclassLike(const Tensor& tensor) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  return !(key_set & kTensorSubclassLike).empty();
}
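
// Example call site (hypothetical sketch, not part of this header): guard a
// data_ptr()-based fast path so it is only taken for plain Tensors that are
// guaranteed to have storage:
//
//   if (!isTensorSubclassLike(self)) {
//     // safe to use self.data_ptr() / self.item() here
//   } else {
//     // fall back to a composite-compliant implementation
//   }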
inline bool areAnyTensorSubclassLike(TensorList tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
}
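
// Example call site (hypothetical sketch, not part of this header): a
// backward formula keeps its in-place fast path for plain Tensors and
// switches to an out-of-place, composite-compliant path otherwise:
//
//   Tensor result;
//   if (areAnyTensorSubclassLike({grad, input})) {
//     result = grad * input;   // out-of-place; dispatches through subclasses
//   } else {
//     result = grad.clone();
//     result.mul_(input);      // in-place fast path for plain Tensors
//   }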
inline bool areAnyOptionalTensorSubclassLike(
    const c10::List<c10::optional<Tensor>>& tensors) {
  if (c10::impl::dispatch_mode_enabled())
    return true;
  return std::any_of(
      tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) {
        return (
            opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
      });
}
// Helper function for testing the truthfulness of a scalar tensor
// in a Composite Compliant manner.
// NOTE: This function expects a scalar tensor of boolean dtype.
// E.g.
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
// Composite Compliant Pattern     : is_scalar_tensor_true((t == 0).all())
inline bool is_scalar_tensor_true(const Tensor& t) {
  TORCH_INTERNAL_ASSERT(t.dim() == 0);
  TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool);
  return at::equal(t, t.new_ones({}, t.options()));
}
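
// Example call site (hypothetical sketch, not part of this header): guard a
// degenerate case without materializing data via .item<bool>():
//
//   if (is_scalar_tensor_true((self == 0).all())) {
//     return at::zeros_like(grad);
//   }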
} // namespace at