# mm.py — inductor lowerings for aten.mm / aten.addmm (autotuned ATen vs. Triton)
import logging

import torch

from ..lowering import register_lowering
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import use_triton_template
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_grid, mm_options

# Module-level logger, per standard `logging` convention.
log = logging.getLogger(__name__)
# Shorthand for the ATen op namespace; used by the @register_lowering decorators below.
aten = torch.ops.aten
# Jinja-style Triton template for a plain matmul C = A @ B.
# {{...}} placeholders are expanded by TritonTemplate; the uppercase symbols
# (BLOCK_M/BLOCK_N/BLOCK_K, GROUP_M, EVEN_K, ALLOW_TF32, ACC_TYPE) are
# presumably injected per-config via mm_options(config, k, layout) — confirm
# in mm_common.  NOTE: the body's indentation is part of the generated kernel
# source and must be preserved.
mm_template = TritonTemplate(
    name="mm",
    grid=mm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", 0)}}
    N = {{size("B", 1)}}
    K = {{size("A", 1)}}
    stride_am = {{stride("A", 0)}}
    stride_ak = {{stride("A", 1)}}
    stride_bk = {{stride("B", 0)}}
    stride_bn = {{stride("B", 1)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
""",
)
# ATen extern-kernel fallbacks the autotuner can select instead of a Triton
# template (Python callable + the C++ out-variant name).
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
aten_addmm = ExternKernelChoice(torch.addmm, "at::addmm_out")
  65. def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
  66. """
  67. Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
  68. kernel under the hood. There are a few shapes where this is slower,
  69. but they are rare.
  70. """
  71. if inp.stride(0) == 0 or inp.size(0) == 1:
  72. return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta)
  73. return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta)
  74. aten_bias_addmm = ExternKernelChoice(bias_addmm, None)
  75. @register_lowering(aten.mm)
  76. def tuned_mm(mat1, mat2, *, layout=None):
  77. m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)
  78. # options to tune from
  79. choices = [aten_mm.bind((mat1, mat2), layout)]
  80. if use_triton_template(layout):
  81. for config in mm_configs():
  82. choices.append(
  83. mm_template.generate(
  84. (mat1, mat2),
  85. layout,
  86. **mm_options(config, k, layout),
  87. )
  88. )
  89. return autotune_select_algorithm(choices, [mat1, mat2], layout)
  90. @register_lowering(aten.addmm)
  91. def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
  92. m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout)
  93. if not use_triton_template(layout):
  94. choices = [aten_addmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
  95. return autotune_select_algorithm(choices, [inp, mat1, mat2], layout)
  96. choices = [
  97. aten_addmm.bind((inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta)
  98. ]
  99. if inp_expanded.get_stride()[0] == 0 and inp_expanded.get_device().type == "cuda":
  100. # unexpand inp to make sure fused addmm from cublasLt is used
  101. choices.insert(
  102. 0,
  103. aten_bias_addmm.bind(
  104. (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta
  105. ),
  106. )
  107. for config in mm_configs():
  108. choices.append(
  109. mm_template.generate(
  110. (inp_expanded, mat1, mat2),
  111. layout,
  112. **mm_options(config, k, layout),
  113. prefix_args=1,
  114. epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
  115. )
  116. )
  117. return autotune_select_algorithm(choices, [inp_expanded, mat1, mat2], layout)