bmm.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import torch
  2. from ..lowering import register_lowering
  3. from ..select_algorithm import (
  4. autotune_select_algorithm,
  5. ExternKernelChoice,
  6. TritonTemplate,
  7. )
  8. from ..utils import ceildiv as cdiv, use_triton_template
  9. from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options
# Short handle to the ATen operator namespace; the lowerings below are keyed
# on ops looked up from it (e.g. aten.bmm).
aten = torch.ops.aten
  11. def bmm_grid(b, m, n, meta):
  12. return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)
# Triton template for batched matmul: for each batch index q it computes
# out[q] = A[q] @ B[q], one (BLOCK_M x BLOCK_N) output tile per program.
# The kernel body is adapted from triton.ops.matmul:
#   * tl.program_id(0) enumerates output tiles, re-ordered into groups of
#     GROUP_M tile-rows for better L2 reuse;
#   * tl.program_id(1) is the batch index (matches axis 1 of bmm_grid);
#   * the K loop accumulates into ACC_TYPE and, when EVEN_K is false, masks
#     the tail of K with zeros.
# The {{...}} placeholders are expanded by TritonTemplate at codegen time.
# BLOCK_M/BLOCK_N/BLOCK_K, GROUP_M, EVEN_K, ACC_TYPE and ALLOW_TF32 are not
# defined here -- presumably supplied as meta-parameters via mm_options(...)
# at the generate() call sites; confirm there.
# NOTE(review): the tl.multiple_of / tl.max_contiguous hints on `rm % M` and
# `rn % N` are compiler contiguity hints; confirm they hold for
# non-contiguous A/B layouts (consider guarding them on the strides).
bmm_template = TritonTemplate(
    name="bmm",
    grid=bmm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", -2)}}
    N = {{size("B", -1)}}
    K = {{size("A", -1)}}
    stride_aq = {{stride("A", 0)}}
    stride_am = {{stride("A", 1)}}
    stride_ak = {{stride("A", 2)}}
    stride_bq = {{stride("B", 0)}}
    stride_bk = {{stride("B", 1)}}
    stride_bn = {{stride("B", 2)}}
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    idx_q = tl.program_id(1) # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1) # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
""",
)
# Extern (non-Triton) autotuning candidates that dispatch to ATen. Each wraps
# the eager callable (torch.bmm / torch.baddbmm) together with the name of the
# ATen out= kernel ("at::bmm_out" / "at::baddbmm_out"). They are bound as the
# always-available baseline choice in tuned_bmm / tuned_baddbmm below.
aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")
  69. @register_lowering(aten.bmm)
  70. def tuned_bmm(mat1, mat2, *, layout=None):
  71. m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
  72. # options to tune from
  73. choices = [aten_bmm.bind((mat1, mat2), layout)]
  74. if use_triton_template(layout):
  75. for config in mm_configs():
  76. choices.append(
  77. bmm_template.generate(
  78. (mat1, mat2),
  79. layout,
  80. **mm_options(config, k, layout),
  81. )
  82. )
  83. return autotune_select_algorithm(choices, [mat1, mat2], layout)
  84. # Don't register this since it is slower than decomposing it
  85. # @register_lowering(aten.baddbmm)
  86. def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
  87. m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)
  88. # options to tune from
  89. choices = [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
  90. if use_triton_template(layout):
  91. for config in mm_configs():
  92. choices.append(
  93. bmm_template.generate(
  94. (inp, mat1, mat2),
  95. layout,
  96. **mm_options(config, k, layout),
  97. prefix_args=1,
  98. epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
  99. )
  100. )
  101. return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)