LossMulti.h 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/AccumulateType.h>
  4. #include <ATen/Dispatch.h>
  5. #include <ATen/TensorUtils.h>
  6. namespace at { namespace native {
  7. namespace {
  8. static C10_UNUSED void multilabel_margin_loss_shape_check(
  9. int64_t& nframe,
  10. int64_t& dim,
  11. const int64_t& ndims,
  12. TensorArg& target_arg,
  13. const Tensor& input,
  14. const Tensor& target) {
  15. bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0;
  16. TORCH_CHECK(
  17. valid_inputs,
  18. "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
  19. input.sizes());
  20. if (ndims <= 1) {
  21. nframe = 1;
  22. dim = ndims == 0 ? 1 : input.size(0);
  23. TORCH_CHECK(
  24. valid_inputs && target.dim() <= 1 && target.numel() == dim,
  25. "inconsistent size ",
  26. target.sizes(),
  27. " for ",
  28. target_arg);
  29. } else {
  30. nframe = input.size(0);
  31. dim = input.size(1);
  32. TORCH_CHECK(
  33. valid_inputs && target.dim() == 2 && target.size(0) == nframe &&
  34. target.size(1) == dim,
  35. "inconsistent size ",
  36. target.sizes(),
  37. " for ",
  38. target_arg);
  39. }
  40. }
  41. static C10_UNUSED void multi_margin_loss_shape_check(
  42. int64_t& nframe,
  43. int64_t& dim,
  44. const int64_t& ndims,
  45. TensorArg& target_arg,
  46. const Tensor& input,
  47. const Tensor& target) {
  48. bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0;
  49. if (ndims <= 1) {
  50. nframe = 1;
  51. dim = ndims == 0 ? 1 : input.size(0);
  52. } else {
  53. nframe = input.size(0);
  54. dim = input.size(1);
  55. }
  56. TORCH_CHECK(
  57. valid_inputs,
  58. "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
  59. input.sizes());
  60. TORCH_CHECK(
  61. valid_inputs && target.dim() <= 1 && target.numel() == nframe,
  62. "inconsistent target size, got: ",
  63. target.sizes());
  64. }
  65. } // anonymous namespace
  66. }} // namespace at::native