conv.py

import torch

from ..utils import has_triton

if has_triton():
    import triton
    import triton.language as tl

    from .autotune import conv_heuristics
    from .utils import _unpack

    @conv_heuristics()
    @triton.jit
    def _kernel_delta_x_hwc(
        x,
        w,
        y,
        # stride of tensor
        stride_xn,
        stride_xc,
        stride_xh,
        stride_xw,
        stride_wn,
        stride_wc,
        stride_wh,
        stride_ww,
        stride_yn,
        stride_yc,
        stride_yh,
        stride_yw,
        stride_biasn,
        # pointer inc for x
        delta_xh_ptr,
        delta_xw_ptr,
        delta_xc_ptr,
        # Tensor dimensions
        BATCH,
        IN_C,
        IN_H,
        IN_W,
        KERNEL_N,
        KERNEL_H,
        KERNEL_W,
        OUT_H,
        OUT_W,
        # parameters of conv
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
        # Metaparameters
        ACC_TYPE: tl.constexpr,
        CONV1X1_NHWC: tl.constexpr,
        # blocks in different dimensions
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        # reduction tiling parameter for matmul
        BLOCK_K: tl.constexpr,
        # Super-blocking for better L2 performance
        GROUP_H: tl.constexpr,
    ):
        """
        Each program instance computes a [BLOCK_M, BLOCK_N] tile of y, where
        BLOCK_M spans the flattened (batch, out_h, out_w) axis and BLOCK_N
        spans the output channels.
        """
        # -----------------------------------------------------------
        # Map program ids `pid` to the block of y it should compute.
        pid_nhw = tl.program_id(0)
        pid_k = tl.program_id(1)

        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # offset for the initial ptr for x
        off_x_n = off_y_n
        off_x_h = off_y_h * stride_h - padding_h
        off_x_w = off_y_w * stride_w - padding_w
        off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
        off_x_crs = tl.arange(0, BLOCK_K)

        CRS = IN_C * KERNEL_H * KERNEL_W
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_xh_ptrs = delta_xh_ptr + off_x_crs
            delta_xw_ptrs = delta_xw_ptr + off_x_crs
            delta_xc_ptrs = delta_xc_ptr + off_x_crs
            delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
            delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
            off_x_crs_unpacked = (
                delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
            )
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]
            delta_xh = 0
            delta_xw = 0

        mask_x = (
            (off_x_n < BATCH)[:, None]
            & (off_x_crs < CRS)[None, :]
            & (off_x_h[:, None] + delta_xh[None, :] >= 0)
            & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
            & (off_x_w[:, None] + delta_xw[None, :] >= 0)
            & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
        )
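        # Note: padding is handled entirely by this mask -- taps that fall
        # outside the input read 0.0 via the masked loads below, so no
        # explicitly padded copy of x is ever materialized.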
        # offset for the initial ptr for w
        off_w_crs = tl.arange(0, BLOCK_K)
        off_w_k = off_y_k
        w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        # -----------------------------------------------------------
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):
            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------
            w_ptrs += BLOCK_K
            # load inc ptr of x, update x_ptrs
            off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
            if not CONV1X1_NHWC:
                delta_xh_ptrs += BLOCK_K
                delta_xw_ptrs += BLOCK_K
                delta_xc_ptrs += BLOCK_K
                delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
                delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
                delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
                off_x_crs_unpacked = (
                    delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
                )
                x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
            else:
                x_ptrs += BLOCK_K

            mask_x = (
                (off_x_n < BATCH)[:, None]
                & (off_x_crs < CRS)[None, :]
                & (off_x_h[:, None] + delta_xh[None, :] >= 0)
                & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
                & (off_x_w[:, None] + delta_xw[None, :] >= 0)
                & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
            )
            mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
            # ------ prefetch ------
            # ------ load x ------
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
            # ------ load w ------
            matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        acc = acc.to(y.dtype.element_ty)

        # rematerialize -- this saves some registers
        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        # consider output padding
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # y ptrs in the block of [BLOCK_M, BLOCK_N]
        y_ptrs = (
            y
            + off_y_n[:, None] * stride_yn
            + off_y_h[:, None] * stride_yh
            + off_y_w[:, None] * stride_yw
            + off_y_k[None, :] * stride_yc
        )

        # out-of-bounds check
        mask_y = (
            (off_y_n < BATCH)[:, None]
            & (off_y_h < OUT_H + output_padding_h)[:, None]
            & (off_y_w < OUT_W + output_padding_w)[:, None]
            & (off_y_k < KERNEL_N)[None, :]
        )

        tl.store(y_ptrs, acc, mask=mask_y)
        return
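
    # The kernel below is the same implicit GEMM, but it takes a single packed
    # delta_x table instead of per-dimension (h, w, c) deltas. _conv._call
    # selects it when padding is zero or the kernel is 1x1, i.e. whenever no
    # individual tap of the sliding window can step out of bounds relative to
    # its origin, so only the window origin needs to be bounds-checked.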
    @conv_heuristics()
    @triton.jit
    def _kernel_delta_x(
        x,
        w,
        y,
        # stride of tensor
        stride_xn,
        stride_xc,
        stride_xh,
        stride_xw,
        stride_wn,
        stride_wc,
        stride_wh,
        stride_ww,
        stride_yn,
        stride_yc,
        stride_yh,
        stride_yw,
        stride_biasn,
        # pointer inc for x
        delta_x_ptr,
        # Tensor dimensions
        BATCH,
        IN_C,
        IN_H,
        IN_W,
        KERNEL_N,
        KERNEL_H,
        KERNEL_W,
        OUT_H,
        OUT_W,
        # parameters of conv
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
        # Metaparameters
        ACC_TYPE: tl.constexpr,
        CONV1X1_NHWC: tl.constexpr,
        # blocks in different dimensions
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        # reduction tiling parameter for matmul
        BLOCK_K: tl.constexpr,
        # Super-blocking for better L2 performance
        GROUP_H: tl.constexpr,
    ):
        """
        Each program instance computes a [BLOCK_M, BLOCK_N] tile of y, where
        BLOCK_M spans the flattened (batch, out_h, out_w) axis and BLOCK_N
        spans the output channels.
        """
        # -----------------------------------------------------------
        # Map program ids `pid` to the block of y it should compute.
        pid_nhw = tl.program_id(0)
        pid_k = tl.program_id(1)

        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # offset for the initial ptr for x
        off_x_n = off_y_n
        off_x_h = off_y_h * stride_h - padding_h
        off_x_w = off_y_w * stride_w - padding_w
        off_x_nhw = off_x_n * stride_xn + off_x_h * stride_xh + off_x_w * stride_xw
        off_x_crs = tl.arange(0, BLOCK_K)

        CRS = IN_C * KERNEL_H * KERNEL_W
        # load inc ptr of x, update x_ptrs
        if not CONV1X1_NHWC:
            delta_x_ptrs = delta_x_ptr + off_x_crs
            off_x_crs_unpacked = tl.load(delta_x_ptrs, mask=off_x_crs < CRS)
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
        else:
            x_ptrs = x + off_x_nhw[:, None] + off_x_crs[None, :]

        mask_x = (
            (off_x_n < BATCH)
            & (off_x_h >= 0)
            & (off_x_h < IN_H)
            & (off_x_w >= 0)
            & (off_x_w < IN_W)
        )[:, None] & (off_x_crs < CRS)[None, :]

        # offset for the initial ptr for w
        off_w_crs = tl.arange(0, BLOCK_K)
        off_w_k = off_y_k
        w_ptrs = w + off_w_crs[:, None] + off_w_k[None, :] * stride_wn
        mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]

        # ------ load x ------
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
        # ------ load w ------
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        # -----------------------------------------------------------
        # allocate accumulator
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        for crs in range(0, CRS, BLOCK_K):
            # ------ matrix multiplication ------
            acc += tl.dot(matrix_x, matrix_w)
            # ------ update ptrs ------
            w_ptrs += BLOCK_K
            # load inc ptr of x, update x_ptrs
            if not CONV1X1_NHWC:
                delta_x_ptrs += BLOCK_K
                off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
                off_x_crs_unpacked = tl.load(
                    delta_x_ptrs, mask=off_x_crs < CRS, other=0
                )
                x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]
            else:
                off_x_crs = crs + BLOCK_K + tl.arange(0, BLOCK_K)
                x_ptrs += BLOCK_K

            mask_x = (
                (off_x_n < BATCH)
                & (off_x_h >= 0)
                & (off_x_h < IN_H)
                & (off_x_w >= 0)
                & (off_x_w < IN_W)
            )[:, None] & (off_x_crs < CRS)[None, :]
            mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
            # ------ prefetch ------
            # ------ load x ------
            matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
            # ------ load w ------
            matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)

        acc = acc.to(y.dtype.element_ty)

        # rematerialize -- this saves some registers
        # offset for output y
        off_y_k = pid_k * BLOCK_N + tl.arange(0, BLOCK_N)
        off_y_nhw = pid_nhw * BLOCK_M + tl.arange(0, BLOCK_M)
        off_y_n = off_y_nhw // (OUT_H * OUT_W)
        off_y_hw = off_y_nhw % (OUT_H * OUT_W)
        # consider output padding
        off_y_h = off_y_hw // OUT_W + output_padding_h
        off_y_w = off_y_hw % OUT_W + output_padding_w

        # y ptrs in the block of [BLOCK_M, BLOCK_N]
        y_ptrs = (
            y
            + off_y_n[:, None] * stride_yn
            + off_y_h[:, None] * stride_yh
            + off_y_w[:, None] * stride_yw
            + off_y_k[None, :] * stride_yc
        )

        # out-of-bounds check
        mask_y = (
            (off_y_n < BATCH)[:, None]
            & (off_y_h < OUT_H + output_padding_h)[:, None]
            & (off_y_w < OUT_W + output_padding_w)[:, None]
            & (off_y_k < KERNEL_N)[None, :]
        )

        tl.store(y_ptrs, acc, mask=mask_y)
        return
    class _conv:
        kernel = _kernel_delta_x_hwc

        # for the contiguous order of w ptr, what's the corresponding
        # ptr change for x in a sliding window
        @staticmethod
        def _delta_x_ptr_hwc(
            IN_C,
            KERNEL_H,
            KERNEL_W,
            dilation_h,
            dilation_w,
            stride_wc,
            stride_wh,
            stride_ww,
            stride_xc,
            stride_xh,
            stride_xw,
            device,
        ):
            # get the order of axes in w, innermost dimension outward
            stride_w_3d = [stride_wc, stride_wh, stride_ww]
            order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
            window_size = IN_C * KERNEL_H * KERNEL_W

            r_window = torch.arange(0, window_size, 1, device=device)
            window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
            window_unpack_c = window_unpack[order[0]]
            window_unpack_h = window_unpack[order[1]]
            window_unpack_w = window_unpack[order[2]]
            r_dilation_h = dilation_h * window_unpack_h
            r_dilation_w = dilation_w * window_unpack_w
            r_inc = window_unpack_c
            # delta_x = (
            #     r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
            # )
            # return delta_x
            return (
                r_dilation_h,
                r_dilation_w,
                r_inc,
            )
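        # Reading the tables above: for the i-th element of the flattened
        # (C, KERNEL_H, KERNEL_W) reduction axis, enumerated in w's own memory
        # order so that w is read contiguously, the three returned tensors give
        # the (h, w, c) displacement of the matching x element relative to the
        # sliding-window origin. The kernel combines them with x's strides to
        # form addresses, and uses the h/w parts for bounds masking under
        # nonzero padding.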
        @staticmethod
        def _delta_x_ptr(
            IN_C,
            KERNEL_H,
            KERNEL_W,
            dilation_h,
            dilation_w,
            stride_wc,
            stride_wh,
            stride_ww,
            stride_xc,
            stride_xh,
            stride_xw,
            device,
        ):
            # get the order of axes in w, innermost dimension outward
            stride_w_3d = [stride_wc, stride_wh, stride_ww]
            order = sorted(range(len(stride_w_3d)), key=stride_w_3d.__getitem__)
            window_size = IN_C * KERNEL_H * KERNEL_W

            r_window = torch.arange(0, window_size, 1, device=device)
            window_unpack = _unpack(r_window, order, [IN_C, KERNEL_H, KERNEL_W])
            window_unpack_c = window_unpack[order[0]]
            window_unpack_h = window_unpack[order[1]]
            window_unpack_w = window_unpack[order[2]]
            r_dilation_h = dilation_h * window_unpack_h
            r_dilation_w = dilation_w * window_unpack_w
            r_inc = window_unpack_c
            delta_x = (
                r_dilation_h * stride_xh + r_dilation_w * stride_xw + r_inc * stride_xc
            )
            return delta_x
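        # Here the three displacements are pre-combined into a single element
        # offset into x:
        #   delta_x[i] = dh[i] * stride_xh + dw[i] * stride_xw + dc[i] * stride_xc
        # which is all _kernel_delta_x needs once no per-dimension bounds
        # checks are required.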
        @staticmethod
        def _call(
            x,
            w,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ):
            # Q: should we check x, w, bias dtypes?
            device = x.device
            # input shapes
            shape_x = x.shape
            shape_w = w.shape
            shape_bias = bias.shape if bias is not None else None

            # indices for the layout
            xn, xc, xh, xw = 0, 1, 2, 3
            yn, yc, yh, yw = 0, 1, 2, 3
            wn, wc, wh, ww = 0, 1, 2, 3

            # out_channel, in_channel, kernel_height, kernel_width
            kernel_size = [shape_w[wh], shape_w[ww]]
            input_size = [shape_x[xh], shape_x[xw]]
            assert (
                not shape_bias or shape_bias[0] == shape_w[wn]
            ), f"bias shape did not match: {shape_bias} != {shape_w[wn]}"
            in_channel = shape_w[wc] * groups

            assert shape_x[xc] % groups == 0, "in_channels must be divisible by groups"
            assert shape_w[wn] % groups == 0, "out_channels must be divisible by groups"
            assert (
                shape_x[xc] == in_channel
            ), f"in_channel did not match: {shape_x[xc]} != {in_channel}"

            assert (
                len(stride)
                == len(padding)
                == len(dilation)
                == len(output_padding)
                == len(kernel_size)
                == len(input_size)
            )

            # output shape
            shape_y = [0] * 4
            shape_y[yn] = shape_x[xn]
            shape_y[yc] = shape_w[wn]
            shape_y[yh] = (
                input_size[0]
                + 2 * padding[0]
                - dilation[0] * (kernel_size[0] - 1)
                - 1
                + stride[0]
            ) // stride[0] + 2 * output_padding[0]
            shape_y[yw] = (
                input_size[1]
                + 2 * padding[1]
                - dilation[1] * (kernel_size[1] - 1)
                - 1
                + stride[1]
            ) // stride[1] + 2 * output_padding[1]
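            # Base output-size formula (the floor division above absorbs the
            # trailing "+ 1" of the usual form):
            #   OUT = floor((IN + 2*pad - dilation*(K - 1) - 1) / stride) + 1
            # e.g. IN=56, pad=1, K=3, dilation=1, stride=2:
            #   (56 + 2 - 2 - 1 + 2) // 2 = 57 // 2 = 28,
            # matching floor(55 / 2) + 1 = 28. The 2 * output_padding term is
            # added on top, mirroring the output_padding offsets applied to
            # off_y_h / off_y_w inside the kernels.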
            BATCH = shape_x[xn]
            IN_C = shape_x[xc]
            IN_H = shape_x[xh]
            IN_W = shape_x[xw]
            KERNEL_N = shape_w[wn]
            KERNEL_H = shape_w[wh]
            KERNEL_W = shape_w[ww]
            OUT_H = shape_y[yh]
            OUT_W = shape_y[yw]

            # allocate output
            y = torch.empty(shape_y, device=device, dtype=x.dtype)
            # get strides for tensors
            stride_x = x.stride()
            stride_w = w.stride()
            stride_bias = bias.stride() if shape_bias else None
            stride_biasn = stride_bias[0] if stride_bias else None

            # output layout should be the same as x
            if stride_x[xc] < stride_x[xh] and stride_x[xc] < stride_x[xw]:
                y = y.to(memory_format=torch.channels_last)
            stride_y = y.stride()

            # allocate tmp
            # WINDOW_SIZE = KERNEL_H * KERNEL_W * IN_C
            # tmp_x = torch.empty((BATCH * OUT_H * OUT_W, WINDOW_SIZE), device=device, dtype=x.dtype)
            # tmp_w = torch.empty((WINDOW_SIZE, KERNEL_N), device=device, dtype=w.dtype)

            # accumulator types
            ACC_TYPE = (
                tl.float32
                if x.dtype in [torch.float16, torch.bfloat16, torch.float32]
                else tl.int32
            )
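            # Accumulating in fp32 (or int32 for integer inputs) keeps the
            # BLOCK_K-tiled dot products numerically stable even when x, w,
            # and y are fp16/bf16; the accumulator is cast back to y's element
            # type only after the reduction loop finishes.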
            # if stride_x[xc] == 1 and stride_x > 1 and stride_y > 1:
            CONV1X1_NHWC = False
            if stride_x[xc] == 1 and KERNEL_H == 1 and KERNEL_W == 1:
                CONV1X1_NHWC = True

            # do we need a delta-x ptr for each of the h, w, c dimensions or not
            DELTA_X_PTR_HWC = not (
                (padding[0] == 0 and padding[1] == 0)
                or (KERNEL_H == 1 and KERNEL_W == 1)
            )
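            # CONV1X1_NHWC: with a channels-last input and a 1x1 kernel, the
            # reduction axis is just IN_C and is already contiguous in x, so
            # the kernels skip the delta tables entirely. DELTA_X_PTR_HWC
            # chooses between the two kernels above: per-dimension deltas are
            # only needed when nonzero padding can push individual taps of a
            # multi-tap window out of bounds.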
            if not CONV1X1_NHWC:
                if DELTA_X_PTR_HWC:
                    delta_xh, delta_xw, delta_xc = _conv._delta_x_ptr_hwc(
                        IN_C,
                        KERNEL_H,
                        KERNEL_W,
                        dilation[0],
                        dilation[1],
                        stride_w[wc],
                        stride_w[wh],
                        stride_w[ww],
                        stride_x[xc],
                        stride_x[xh],
                        stride_x[xw],
                        device,
                    )
                else:
                    delta_x = _conv._delta_x_ptr(
                        IN_C,
                        KERNEL_H,
                        KERNEL_W,
                        dilation[0],
                        dilation[1],
                        stride_w[wc],
                        stride_w[wh],
                        stride_w[ww],
                        stride_x[xc],
                        stride_x[xh],
                        stride_x[xw],
                        device,
                    )
            else:
                delta_x = None
                delta_xh, delta_xw, delta_xc = None, None, None

            # launch kernel, 2-dim, batch*h*w, kernel
            def grid(META):
                return (
                    triton.cdiv(BATCH * OUT_H * OUT_W, META["BLOCK_M"]),
                    triton.cdiv(KERNEL_N, META["BLOCK_N"]),
                )
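            # 2D launch grid: axis 0 tiles the flattened BATCH * OUT_H * OUT_W
            # rows, axis 1 tiles the KERNEL_N output channels; BLOCK_M and
            # BLOCK_N are supplied through META by the @conv_heuristics()
            # decorator (which is why they are commented out in the launches
            # below).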
            # conv1x1 or padding == 0
            if CONV1X1_NHWC or not DELTA_X_PTR_HWC:
                _kernel_delta_x[grid](
                    x,
                    w,
                    y,
                    # stride nchw for x,w,y tensor
                    stride_x[xn],
                    stride_x[xc],
                    stride_x[xh],
                    stride_x[xw],
                    stride_w[wn],
                    stride_w[wc],
                    stride_w[wh],
                    stride_w[ww],
                    stride_y[yn],
                    stride_y[yc],
                    stride_y[yh],
                    stride_y[yw],
                    stride_biasn,
                    # pointer inc for x
                    delta_x,
                    # Tensor dimensions
                    BATCH,
                    IN_C,
                    IN_H,
                    IN_W,
                    KERNEL_N,
                    KERNEL_H,
                    KERNEL_W,
                    OUT_H,
                    OUT_W,
                    # conv parameters
                    stride[0],
                    stride[1],
                    padding[0],
                    padding[1],
                    dilation[0],
                    dilation[1],
                    output_padding[0],
                    output_padding[1],
                    groups,
                    # Metaparameters
                    ACC_TYPE=ACC_TYPE,
                    CONV1X1_NHWC=CONV1X1_NHWC,
                    # BLOCK_M=128,
                    # BLOCK_N=32,
                    # BLOCK_K=32,
                    GROUP_H=1,
                )
            # need to know ptr update for each dimension to check if
            # the sliding window is out of bounds
            else:
                # kernel = _kernel_delta_x_hwc
                _kernel_delta_x_hwc[grid](
                    x,
                    w,
                    y,
                    # stride nchw for x,w,y tensor
                    stride_x[xn],
                    stride_x[xc],
                    stride_x[xh],
                    stride_x[xw],
                    stride_w[wn],
                    stride_w[wc],
                    stride_w[wh],
                    stride_w[ww],
                    stride_y[yn],
                    stride_y[yc],
                    stride_y[yh],
                    stride_y[yw],
                    stride_biasn,
                    # pointer inc for x
                    delta_xh,
                    delta_xw,
                    delta_xc,
                    # Tensor dimensions
                    BATCH,
                    IN_C,
                    IN_H,
                    IN_W,
                    KERNEL_N,
                    KERNEL_H,
                    KERNEL_W,
                    OUT_H,
                    OUT_W,
                    # conv parameters
                    stride[0],
                    stride[1],
                    padding[0],
                    padding[1],
                    dilation[0],
                    dilation[1],
                    output_padding[0],
                    output_padding[1],
                    groups,
                    # Metaparameters
                    ACC_TYPE=ACC_TYPE,
                    CONV1X1_NHWC=CONV1X1_NHWC,
                    # BLOCK_M=128,
                    # BLOCK_N=32,
                    # BLOCK_K=32,
                    GROUP_H=1,
                )
            if bias is not None:
                if len(bias.shape) == 1:
                    bias = bias.reshape([1, bias.shape[0], 1, 1])
                y += bias
            return y

        @staticmethod
        def forward(
            x,
            w,
            bias,
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            transposed=False,
            output_padding=(0, 0),
            groups=1,
        ):
            if groups != 1:
                print(f"groups = {groups} is not supported")
                return
            if transposed:
                print("transposed is not supported")
                return
            return _conv._call(
                x,
                w,
                bias,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )

    conv = _conv.forward
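
    # Minimal usage sketch (hypothetical shapes and tolerances, assuming a
    # CUDA device with Triton available). `conv` takes
    # (x, w, bias, stride, padding, dilation, transposed, output_padding,
    # groups) and can be checked against torch.nn.functional.conv2d:
    #
    #   x = torch.randn(32, 64, 56, 56, device="cuda", dtype=torch.float16)
    #   x = x.to(memory_format=torch.channels_last)
    #   w = torch.randn(128, 64, 3, 3, device="cuda", dtype=torch.float16)
    #   w = w.to(memory_format=torch.channels_last)
    #   bias = torch.randn(128, device="cuda", dtype=torch.float16)
    #   y = conv(x, w, bias, stride=(1, 1), padding=(1, 1))
    #   y_ref = torch.nn.functional.conv2d(x, w, bias, stride=1, padding=1)
    #   torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2)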