// TriangularOpsUtils.h (2.0 KB)
  1. #include <ATen/core/Tensor.h>
  2. #include <ATen/native/LinearAlgebraUtils.h>
  3. namespace at {
  4. namespace native {
  5. /*
  6. * Given batches of matrices with arbitrary batch dim,
  7. * computes the number of batches for Triu and Tril. This ignores stride 0 dimension
  8. */
  9. static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
  10. int64_t result = 1;
  11. for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
  12. if (batched_matrices.stride(i) != 0) {
  13. result *= batched_matrices.size(i);
  14. }
  15. }
  16. return result;
  17. }
  18. /* Checks a necessary property for the triu and tril implementations, hence the name.
  19. * Here batch contiguity is checked for tensors with greater than 4 dimensions.
  20. * Contiguous tensors and tensors with less than 3 dimensions pass this check
  21. */
  22. static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
  23. // Complete contiguity is the most desired property, which is why
  24. // we return true if the tensor is contiguous
  25. if (tensor.is_contiguous()) {
  26. auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
  27. if (tensor.strides() == default_strides_for_size) {
  28. return std::make_tuple(true, tensor);
  29. } else {
  30. return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
  31. }
  32. }
  33. int64_t dims = tensor.dim();
  34. // Tensors with dimension less than 4 are handled by default
  35. if (allow_zero_stride && dims <= 3) {
  36. return std::make_tuple(true, tensor);
  37. }
  38. int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
  39. for (int64_t i = dims - 3; i >= 0; i--) {
  40. // Skip trivial dimension;
  41. if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
  42. continue;
  43. }
  44. if (expected_stride != tensor.stride(i)) {
  45. return std::make_tuple(false, tensor.contiguous());
  46. }
  47. expected_stride *= tensor.size(i);
  48. }
  49. return std::make_tuple(true, tensor);
  50. }
  51. } // namespace native
  52. } // namespace at