vol2col.h 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #pragma once
  2. #include <cstring>
  3. namespace at {
  4. namespace native {
  5. template <typename T>
  6. static void vol2col(
  7. const T* data_vol,
  8. const int64_t channels,
  9. const int64_t depth,
  10. const int64_t height,
  11. const int64_t width,
  12. const int64_t depth_col,
  13. const int64_t height_col,
  14. const int64_t width_col,
  15. const int64_t kT,
  16. const int64_t kernel_height,
  17. const int64_t kernel_width,
  18. const int64_t pT,
  19. const int64_t pH,
  20. const int64_t pW,
  21. const int64_t dT,
  22. const int64_t dH,
  23. const int64_t dW,
  24. const int64_t dilationT,
  25. const int64_t dilationH,
  26. const int64_t dilationW,
  27. T* data_col) {
  28. int64_t c, t, h, w;
  29. int64_t channels_col = channels * kT * kernel_height * kernel_width;
  30. for (c = 0; c < channels_col; ++c) {
  31. int64_t w_offset = c % kernel_width;
  32. int64_t h_offset = (c / kernel_width) % kernel_height;
  33. int64_t t_offset = (c / kernel_width / kernel_height) % kT;
  34. int64_t c_vol = c / kT / kernel_height / kernel_width;
  35. for (t = 0; t < depth_col; ++t) {
  36. int64_t t_pad = t * dT - pT + t_offset * dilationT;
  37. for (h = 0; h < height_col; ++h) {
  38. int64_t h_pad = h * dH - pH + h_offset * dilationH;
  39. for (w = 0; w < width_col; ++w) {
  40. int64_t w_pad = w * dW - pW + w_offset * dilationW;
  41. if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height &&
  42. w_pad >= 0 && w_pad < width)
  43. data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
  44. data_vol
  45. [((c_vol * depth + t_pad) * height + h_pad) * width +
  46. w_pad];
  47. else
  48. data_col[((c * depth_col + t) * height_col + h) * width_col + w] =
  49. 0;
  50. }
  51. }
  52. }
  53. }
  54. }
  55. template <typename T>
  56. static void col2vol(
  57. const T* data_col,
  58. const int64_t channels,
  59. const int64_t depth,
  60. const int64_t height,
  61. const int64_t width,
  62. const int64_t out_depth,
  63. const int64_t out_height,
  64. const int64_t out_width,
  65. const int64_t kT,
  66. const int64_t kernel_height,
  67. const int64_t kernel_width,
  68. const int64_t pT,
  69. const int64_t pH,
  70. const int64_t pW,
  71. const int64_t dT,
  72. const int64_t dH,
  73. const int64_t dW,
  74. const int64_t dilationT,
  75. const int64_t dilationH,
  76. const int64_t dilationW,
  77. T* data_vol) {
  78. int64_t c, t, h, w;
  79. memset(data_vol, 0, sizeof(T) * depth * height * width * channels);
  80. int64_t depth_col = out_depth;
  81. int64_t height_col = out_height;
  82. int64_t width_col = out_width;
  83. int64_t channels_col = channels * kT * kernel_height * kernel_width;
  84. for (c = 0; c < channels_col; ++c) {
  85. int64_t w_offset = c % kernel_width;
  86. int64_t h_offset = (c / kernel_width) % kernel_height;
  87. int64_t t_offset = (c / kernel_width / kernel_height) % kT;
  88. int64_t c_vol = c / kT / kernel_height / kernel_width;
  89. for (t = 0; t < depth_col; ++t) {
  90. int64_t t_pad = t * dT - pT + t_offset * dilationT;
  91. for (h = 0; h < height_col; ++h) {
  92. int64_t h_pad = h * dH - pH + h_offset * dilationH;
  93. for (w = 0; w < width_col; ++w) {
  94. int64_t w_pad = w * dW - pW + w_offset * dilationW;
  95. if (t_pad >= 0 && t_pad < depth && h_pad >= 0 && h_pad < height &&
  96. w_pad >= 0 && w_pad < width)
  97. data_vol
  98. [((c_vol * depth + t_pad) * height + h_pad) * width + w_pad] +=
  99. data_col
  100. [((c * depth_col + t) * height_col + h) * width_col + w];
  101. }
  102. }
  103. }
  104. }
  105. }
  106. } // namespace native
  107. } // namespace at