mm_common.py 3.5 KB

import functools
import logging

import sympy
import torch

from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V

from ..utils import ceildiv as cdiv

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def mm_configs():
    import triton

    return [
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=2, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=3, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=3, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=4, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32}, num_stages=5, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32}, num_stages=5, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, num_stages=2, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16}, num_stages=1, num_warps=2
        ),
    ]

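
# Usage sketch (illustrative, not part of this module's own code): these
# configs are the candidate tile/pipeline shapes that the matmul templates
# benchmark. A config is typically expanded into concrete template options via
# mm_options() below; `sym_k` and `layout` here stand in for the real symbolic
# K dimension and output layout.
#
#     for cfg in mm_configs():
#         opts = mm_options(cfg, sym_k, layout)
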

def mm_grid(m, n, meta):
    """
    The CUDA grid size for matmul triton templates.
    """
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)

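
# Worked example (assuming cdiv is ceiling division, as imported above): a
# 512x512 output with BLOCK_M = 64 and BLOCK_N = 128 is covered by
# ceil(512 / 64) * ceil(512 / 128) = 8 * 4 = 32 programs along axis 0:
#
#     mm_grid(512, 512, {"BLOCK_M": 64, "BLOCK_N": 128})  # -> (32, 1, 1)
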

def acc_type(dtype):
    # Half-precision inputs accumulate in fp32; other dtypes accumulate in
    # their own type.
    if dtype in (torch.float16, torch.bfloat16):
        return "tl.float32"
    return f"tl.{dtype}".replace("torch.", "")

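
# Illustrative mapping, derived from the branches above:
#
#     acc_type(torch.float16)   # -> "tl.float32"
#     acc_type(torch.bfloat16)  # -> "tl.float32"
#     acc_type(torch.float32)   # -> "tl.float32"
#     acc_type(torch.int8)      # -> "tl.int8"
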

def mm_options(config, sym_k, layout):
    """
    Common options to matmul triton templates.
    """
    even_k_symbolic = (
        # it isn't worth guarding on this
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
        == config.kwargs["BLOCK_K"]
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )

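
# EVEN_K sketch: the gcd test marks K as evenly divisible by BLOCK_K only when
# that is provable from the (possibly symbolic) size. For BLOCK_K = 32:
#
#     sympy.gcd(4096, 32)                # -> 32, so EVEN_K is True
#     sympy.gcd(100, 32)                 # -> 4, so EVEN_K is False
#     sympy.gcd(sympy.Symbol("s0"), 32)  # typically 1 for an unknown K -> False
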

def mm_args(mat1, mat2, *others, layout=None):
    """
    Common arg processing for mm, bmm, addmm, etc.
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    *b2, k2, n = mat2.get_size()
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        layout = FixedLayout(
            mat1.get_device(),
            mat1.get_dtype(),
            [*b, m, n],
        )

    from ..lowering import expand

    others = [realize_inputs(expand(x, layout.size)) for x in others]
    return [m, n, k, layout, mat1, mat2, *others]

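
# Shape-handling sketch (requires an active Inductor graph for the size-var
# guards; names are illustrative): for mat1 of size [b, m, k1] and mat2 of
# size [b, k2, n], the batch and K dims are guarded equal and the returned
# layout describes a [b, m, n] output:
#
#     m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2)
#     layout.size  # -> [b, m, n]
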

def addmm_epilogue(dtype, alpha, beta):
    # Build a pointwise epilogue computing alpha * acc + beta * bias, skipping
    # the multiplies when the scalars are 1.
    def epilogue(acc, bias):
        if alpha != 1:
            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
        if beta != 1:
            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
        return V.ops.add(acc, bias)

    return epilogue
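
# Usage sketch (values are illustrative): an addmm lowering with alpha=2 and
# beta=3 builds an epilogue whose emitted ops are equivalent to
# 2 * acc + 3 * bias, evaluated elementwise inside the template:
#
#     epilogue = addmm_epilogue(torch.float32, 2, 3)
#     out = epilogue(acc, bias)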