conv1x1.py 6.7 KB

import torch

from ..utils import has_triton

if has_triton():
    import triton

    class _conv1x1:
        @staticmethod
        def _call(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ):
            # Q: should we check x, w, bias dtypes?
            device = x.device
            # input shapes
            shape_x = x.shape
            shape_w = w.shape
            shape_bias = bias.shape if bias is not None else None

            # indices for the layout
            xn, xc, xh, xw = 0, 1, 2, 3
            yn, yc, yh, yw = 0, 1, 2, 3
            wn, wc, wh, ww = 0, 1, 2, 3

            # out_channel, in_channel, kernel_height, kernel_width
            kernel_size = [shape_w[wh], shape_w[ww]]
            input_size = [shape_x[xh], shape_x[xw]]
            assert (
                not shape_bias or shape_bias[0] == shape_w[wn]
            ), f"bias shape did not match {shape_bias} != {shape_w[wn]}"
            in_channel = shape_w[wc] * groups

            assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
            assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
            assert (
                shape_x[xc] == in_channel
            ), f"in_channel did not match {shape_x[xc]} != {in_channel}"

            assert (
                len(stride)
                == len(padding)
                == len(dilation)
                == len(output_padding)
                == len(kernel_size)
                == len(input_size)
            )

            # output shape
            shape_y = [0] * 4
            shape_y[yn] = shape_x[xn]
            shape_y[yc] = shape_w[wn]
            shape_y[yh] = (
                input_size[0]
                + 2 * padding[0]
                - dilation[0] * (kernel_size[0] - 1)
                - 1
                + stride[0]
            ) // stride[0] + 2 * output_padding[0]
            shape_y[yw] = (
                input_size[1]
                + 2 * padding[1]
                - dilation[1] * (kernel_size[1] - 1)
                - 1
                + stride[1]
            ) // stride[1] + 2 * output_padding[1]
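            # Example (hypothetical numbers): input_size[0] = 56, padding[0] = 0,
            # dilation[0] = 1, kernel_size[0] = 1, stride[0] = 2, output_padding[0] = 0
            # gives shape_y[yh] = (56 + 0 - 0 - 1 + 2) // 2 + 0 = 28, which matches the
            # usual floor((H + 2p - d*(k - 1) - 1) / s) + 1 convolution output size.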

            BATCH = shape_x[xn]
            IN_C = shape_x[xc]
            # IN_H = shape_x[xh]
            # IN_W = shape_x[xw]
            KERNEL_N = shape_w[wn]
            KERNEL_H = shape_w[wh]
            KERNEL_W = shape_w[ww]
            OUT_H = shape_y[yh]
            OUT_W = shape_y[yw]

            assert KERNEL_H == 1 and KERNEL_W == 1, "only support 1x1 conv"
            channels_last = x.stride()[1] == 1
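            # A stride of 1 along dim 1 (channels) means x is stored channels-last
            # (NHWC in memory). In either layout a 1x1 convolution reduces to a matmul
            # of the flattened (batch * height * width, IN_C) input against the
            # (IN_C, KERNEL_N) weight matrix, which is what both branches below do.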
            if padding == (0, 0):
                # nchw -> nhwc
                x = x.permute(0, 2, 3, 1)
                # select every stride's element (for stride > 1)
                x = x[:, :: stride[0], :: stride[1], :]
                # 2d matrix
                mat_x = x.reshape(-1, IN_C)
                # 2d matrix
                mat_w = w.view(KERNEL_N, IN_C)
                mat_w = mat_w.permute(1, 0)
                # 2d matrix y, (BATCH * OUT_H * OUT_W, KERNEL_N)
                mat_y = triton.ops.matmul(mat_x, mat_w)
                # mat_y = torch.empty((BATCH * OUT_H * OUT_W, KERNEL_N), device=device, dtype=x.dtype,)
                y = mat_y.view(BATCH, OUT_H, OUT_W, KERNEL_N)
                if bias is not None:
                    y += bias
                # convert back to the original layout of y
                # nhwc -> nchw
                y = y.permute(0, 3, 1, 2)
                if not channels_last:
                    y = y.to(memory_format=torch.contiguous_format)
                return y
            else:
                y = torch.empty(
                    (shape_y[yn], shape_y[yh], shape_y[yw], shape_y[yc]),
                    device=device,
                    dtype=x.dtype,
                )
                if channels_last:
                    y = y.to(memory_format=torch.channels_last)
                # y = bias.repeat((shape_y[yn], shape_y[yh], shape_y[yw], 1)).to(device).type(x.dtype)
                # convert x to channels-last layout;
                # don't care about w's layout since the kernel size is 1
                x = x.permute(0, 2, 3, 1)
                # select every stride's element (for stride > 1)
                x = x[:, :: stride[0], :: stride[1], :]
                # 2d matrix
                mat_x = x.view(-1, IN_C)
                # 2d matrix
                mat_w = w.view(KERNEL_N, IN_C)
                mat_w = mat_w.permute(1, 0)
                # 2d matrix y, (BATCH * (OUT_H - 2 * padding[0]) * (OUT_W - 2 * padding[1]), KERNEL_N)
                mat_y = triton.ops.matmul(mat_x, mat_w)
                mat_y = mat_y.view(
                    BATCH, OUT_H - 2 * padding[0], OUT_W - 2 * padding[1], KERNEL_N
                )
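                # With a 1x1 kernel, output rows/columns that fall in the padded
                # border see only zero-padded input, so they reduce to bias (or 0);
                # only the interior of y needs the matmul result.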
                # consider padding > 0
                if bias is not None:
                    y[
                        :,
                        padding[0] : OUT_H - padding[0],
                        padding[1] : OUT_W - padding[1],
                        :,
                    ] = (
                        mat_y + bias
                    )
                    y[:, : padding[0], :, :] = bias
                    y[:, :, : padding[1], :] = bias
                    y[:, OUT_H - padding[0] :, :, :] = bias
                    y[:, :, OUT_W - padding[1] :, :] = bias
                else:
                    y[
                        :,
                        padding[0] : OUT_H - padding[0],
                        padding[1] : OUT_W - padding[1],
                        :,
                    ] = mat_y
                    y[:, : padding[0], :, :] = 0
                    y[:, :, : padding[1], :] = 0
                    y[:, OUT_H - padding[0] :, :, :] = 0
                    y[:, :, OUT_W - padding[1] :, :] = 0
                # convert back to the original layout of y
                # nhwc -> nchw
                y = y.permute(0, 3, 1, 2)
                return y

        @staticmethod
        def forward(
            x,
            w,
            bias,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            transposed=False,
            output_padding=(0, 0),
            groups=1,
        ):
            if groups != 1:
                print(f"Do not support groups = {groups}")
                return
            if transposed:
                print("Do not support transposed")
                return
            return _conv1x1._call(
                x,
                w,
                bias,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )

    conv1x1 = _conv1x1.forward
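
    # Usage sketch (hypothetical shapes; assumes a CUDA device with Triton installed).
    # conv1x1 follows the aten convolution argument order, so an unpadded, strided
    # 1x1 conv can be checked against torch.nn.functional.conv2d:
    #
    #   x = torch.randn(8, 64, 56, 56, device="cuda", dtype=torch.float16)
    #   w = torch.randn(128, 64, 1, 1, device="cuda", dtype=torch.float16)
    #   b = torch.randn(128, device="cuda", dtype=torch.float16)
    #   y = conv1x1(x, w, b, stride=(2, 2), padding=(0, 0))
    #   ref = torch.nn.functional.conv2d(x, w, b, stride=(2, 2))
    #   torch.testing.assert_close(y, ref, rtol=1e-2, atol=1e-2)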