layer_norm.h

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/accumulate.h>
#include <sstream> // std::stringstream, used for the shape-mismatch error below
#include <utility> // std::pair

namespace at {
namespace native {

namespace {
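
// Validates that weight and bias (when defined) match normalized_shape and
// that the trailing dimensions of `input` equal normalized_shape, then
// returns the pair (M, N): the product of the leading ("batch") dimensions
// and the product of the normalized dimensions, respectively.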
C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_layer_norm_inputs(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& weight /* optional */,
    const Tensor& bias /* optional */) {

  const int normalized_ndim = normalized_shape.size();
  TORCH_CHECK(
      normalized_ndim >= 1,
      "Expected normalized_shape to be at least 1-dimensional, i.e., ",
      "containing at least one element, but got normalized_shape = ",
      normalized_shape);

  TORCH_CHECK(
      !weight.defined() || weight.sizes().equals(normalized_shape),
      "Expected weight to be of same shape as normalized_shape, but got ",
      "weight of shape ",
      weight.sizes(),
      " and normalized_shape = ",
      normalized_shape);

  TORCH_CHECK(
      !bias.defined() || bias.sizes().equals(normalized_shape),
      "Expected bias to be of same shape as normalized_shape, but got ",
      "bias of shape ",
      bias.sizes(),
      " and normalized_shape = ",
      normalized_shape);

  const auto input_shape = input.sizes();
  const auto input_ndim = input.dim();
  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size " << input_shape;
    AT_ERROR(ss.str());
  }

  const int axis = input_ndim - normalized_ndim;
  const int64_t M =
      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());

  return std::make_pair(M, N);
}
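
// Worked example of the split above: for input of shape [2, 3, 4] with
// normalized_shape = [4], axis = 2, so M = 2 * 3 = 6 and N = 4; the kernels
// then treat the input as 6 independent rows of 4 elements each.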

} // namespace
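
// CPU layer norm writing into a preallocated `out`. gamma and beta are the
// optional affine parameters (they may be undefined Tensors), and M/N are the
// row/column counts produced by _check_layer_norm_inputs above.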
void layer_norm_cpu_out(
    at::Tensor& out,
    const at::Tensor& input,
    const Tensor& gamma,
    const Tensor& beta,
    double eps,
    int64_t M,
    int64_t N);
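
// Signature of the per-device forward kernel registered via the DispatchStub
// below: it writes the normalized output Y along with the per-row mean and
// reciprocal standard deviation (rstd), which are saved for the backward pass.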
using forward_fn = void (*)(
    const Tensor& /* X */,
    const Tensor& /* gamma */,
    const Tensor& /* beta */,
    int64_t /* M */,
    int64_t /* N */,
    double /* eps */,
    Tensor* /* Y */,
    Tensor* /* mean */,
    Tensor* /* rstd */);
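
// Signature of the matching backward kernel: it consumes the saved mean/rstd
// and produces gradients for the input (dX) and the affine parameters
// (dgamma, dbeta).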
using backward_fn = void (*)(
    const Tensor& /* dY */,
    const Tensor& /* X */,
    const Tensor& /* mean */,
    const Tensor& /* rstd */,
    const Tensor& /* gamma */,
    int64_t /* M */,
    int64_t /* N */,
    Tensor* /* dX */,
    Tensor* /* dgamma */,
    Tensor* /* dbeta */);

DECLARE_DISPATCH(forward_fn, LayerNormKernel);
DECLARE_DISPATCH(backward_fn, LayerNormBackwardKernel);
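
// A minimal usage sketch (hypothetical caller, not part of this header):
// validate shapes, allocate the outputs, then invoke the dispatched forward
// kernel. The allocation details (dtype of mean/rstd, dispatch key) are
// illustrative assumptions, not the canonical ATen implementation.
//
//   auto [M, N] =
//       _check_layer_norm_inputs(input, normalized_shape, weight, bias);
//   Tensor Y = at::empty_like(input);
//   Tensor mean = at::empty({M}, input.options()); // per-row mean
//   Tensor rstd = at::empty({M}, input.options()); // per-row 1/sqrt(var + eps)
//   LayerNormKernel(kCPU, input, weight, bias, M, N, eps, &Y, &mean, &rstd);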

} // namespace native
} // namespace at