autotuner.py

import builtins

import torch
from torch._dynamo.testing import rand_strided

from .. import config, triton_ops
from ..virtualized import V

aten = torch.ops.aten


def str2func(str):
    module, *name = str.split(".")
    if module == "aten":
        runnable = aten
    elif module == "triton_ops":
        runnable = triton_ops
    elif module == "torch":
        runnable = torch
    else:
        raise Exception(f"{str} could not be called")

    for n in name:
        runnable = getattr(runnable, n)
    return runnable
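
# Illustrative resolution (descriptive comment, not executed here):
# str2func("aten.convolution") walks getattr from torch.ops.aten down to
# torch.ops.aten.convolution, and str2func("triton_ops.conv") resolves to
# triton_ops.conv; any other module prefix raises Exception.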


class Autotuner:
    def __init__(self):
        self.cache = dict()

    def _bench(self, kernel, *args, **kwargs):
        def kernel_call():
            kernel(*args, **kwargs)

        from triton.testing import do_bench

        return do_bench(kernel_call)


autotune = Autotuner()
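
# Note (descriptive comment, not from the original source): autotune is a
# module-level singleton, so benchmark timings are cached per-process across
# calls. _bench imports triton.testing.do_bench lazily at call time; the call
# sites below unpack a 3-tuple from its return value and keep only the first
# element as the timing.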


def tuned_conv(
    x_shape,
    w_shape,
    x_stride,
    w_stride,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    device,
    dtype,
    adjust_triton=0.95,
):
    """
    Return the name of the best kernel for the given inputs and layer parameters.
    To account for potential pointwise fusion into the conv, the Triton timing
    is scaled by adjust_triton (default 0.95).
    """
    sizevars = V.graph.sizevars
    x_shape = [sizevars.size_hint(s) for s in x_shape]
    w_shape = [sizevars.size_hint(s) for s in w_shape]
    x_stride = [sizevars.size_hint(s) for s in x_stride]
    w_stride = [sizevars.size_hint(s) for s in w_stride]
    x = rand_strided(x_shape, x_stride, device=device, dtype=dtype)
    w = rand_strided(w_shape, w_stride, device=device, dtype=dtype)
    # the args that identify this layer
    id_args = [
        *x_shape,
        *w_shape,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        # *x_stride,
        # *w_stride,
    ]
    use_cuda = x.is_cuda

    # gen_key
    key = tuple(id_args)
    key = ("conv",) + key

    # candidate kernels
    kernels = ["aten.convolution"]
    if use_cuda:
        kernels += ["triton_ops.conv"]
    # filter out kernels whose requirements the args/kwargs do not meet
    remove_kernels = []
    if groups > 1 or transposed:
        remove_kernels += ["triton_ops.conv"]
    kernels = [k for k in kernels if k not in remove_kernels]

    # if there is only one choice, return that kernel without benchmarking
    if len(kernels) == 1:
        kernel = kernels[0]
        # return kernel(
        #     x, w, stride, padding, dilation, transposed, output_padding, groups
        # )
        return kernel

    timings = {}
    if key not in autotune.cache:
        for kernel in kernels:
            runnable_kernel = str2func(kernel)
            if "triton_ops" in kernel:
                # because we use nhwc layout by default for triton conv
                x = x.to(memory_format=torch.channels_last)
            run_args = (
                x,
                w,
                None,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )
            timing, _, _ = autotune._bench(runnable_kernel, *run_args)
            if "triton_ops" in kernel:
                timing = timing * adjust_triton
            timings[kernel] = timing
        autotune.cache[key] = builtins.min(timings, key=timings.get)
        if config.debug:
            print("for key = ", key)
            print("timing", timings)
            print("best_kernel", autotune.cache[key])
    best_kernel = autotune.cache[key]
    # if best_kernel == "triton_ops.conv":
    #     print(key, best_kernel)
    return best_kernel
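
# Illustrative call (hypothetical shapes/strides; a real call site passes values
# collected during lowering and requires an active V.graph with sizevars):
#
#   best = tuned_conv(
#       x_shape=[8, 64, 56, 56], w_shape=[128, 64, 3, 3],
#       x_stride=[200704, 3136, 56, 1], w_stride=[576, 9, 3, 1],
#       stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False,
#       output_padding=(0, 0), groups=1, device="cuda", dtype=torch.float32,
#   )
#   # best is either "aten.convolution" or "triton_ops.conv"; the choice is
#   # cached under the ("conv", ...) key for identical layers seen later.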


def tuned_conv_layout(
    kernel,
    x_shape,
    w_shape,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    device,
    dtype,
):
    sizevars = V.graph.sizevars
    x_shape = [sizevars.size_hint(s) for s in x_shape]
    w_shape = [sizevars.size_hint(s) for s in w_shape]
    x = torch.randn(x_shape, device=device, dtype=dtype)
    w = torch.randn(w_shape, device=device, dtype=dtype)
    id_args = [
        *x_shape,
        *w_shape,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    ]

    # gen_key
    key = tuple(id_args)
    key = ("conv_layout",) + key
    runnable_kernel = str2func(kernel)

    timings = {}
    if key not in autotune.cache:
        for memory_format in ["torch.contiguous_format", "torch.channels_last"]:
            x = x.to(memory_format=str2func(memory_format))
            run_args = (
                x,
                w,
                None,
                stride,
                padding,
                dilation,
                transposed,
                output_padding,
                groups,
            )
            timing, _, _ = autotune._bench(runnable_kernel, *run_args)
            timings[memory_format] = timing
        autotune.cache[key] = builtins.min(timings, key=timings.get)
        if config.debug:
            print("for key = ", key)
            print("timing", timings)
            print("best_layout", autotune.cache[key])
    best_layout = autotune.cache[key]
    return best_layout
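
# Illustrative call (assumed call site, not part of this file): after tuned_conv
# has picked a kernel name, the lowering code can ask which input memory format
# benchmarks fastest for that kernel, e.g.
#
#   layout = tuned_conv_layout(
#       "aten.convolution", x_shape, w_shape, stride, padding, dilation,
#       transposed, output_padding, groups, device, dtype,
#   )
#   # layout is "torch.contiguous_format" or "torch.channels_last", cached under
#   # the ("conv_layout", ...) key in the same per-process cache.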