virtualized.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import itertools
  2. from contextlib import contextmanager
  3. from itertools import chain
  4. from threading import local
  5. import sympy
  6. from torch._inductor.utils import IndentedBuffer
  7. from torch.fx.graph import inplace_methods, magic_methods
  8. from .utils import sympy_str, sympy_symbol
  threadlocal = local()


  class Virtualized:
      """
      A global variable that redirects via thread local variable.

      This allows us to swap in different op implementations in codegen.
      Attribute access on an instance is forwarded to whatever handler is
      currently installed for the calling thread (see ``_get_handler``).
      """

      def __init__(self, vname, default):
          # Each virtualized global stores its handler on ``threadlocal``
          # under this attribute name.
          self._key = f"__torchinductor_{vname}"
          # Zero-argument factory used whenever no handler is installed.
          self._default = default

      def _set_handler(self, value):
          """Install ``value`` as the current thread's handler.

          The handler is installed immediately (whether or not the result is
          used in a ``with`` statement). The returned context manager restores
          the previously installed handler on exit; if none was installed, it
          removes the thread-local entry again so later lookups fall back to
          ``default()``.
          """
          missing = object()
          prior = getattr(threadlocal, self._key, missing)
          setattr(threadlocal, self._key, value)

          @contextmanager
          def ctx():
              try:
                  yield
              finally:
                  if prior is missing:
                      # Return to "unset" rather than pinning one default
                      # instance for the rest of the thread's lifetime.
                      if hasattr(threadlocal, self._key):
                          delattr(threadlocal, self._key)
                  else:
                      setattr(threadlocal, self._key, prior)

          return ctx()

      def _get_handler(self):
          """Return the current thread's handler, or ``default()`` if unset."""
          try:
              return getattr(threadlocal, self._key)
          except AttributeError:
              return self._default()

      def __getattr__(self, name):
          if name in ("_key", "_default"):
              # Only reached when __init__ never ran (e.g. an instance created
              # via __new__ by copy/pickle); delegating would recurse forever.
              raise AttributeError(name)
          return getattr(self._get_handler(), name)
  class NullHandler:
      """Placeholder default for virtualized globals with nothing installed.

      It defines no attributes of its own, so looking up any handler-specific
      attribute on it raises ``AttributeError``.
      """
  def _arg_str(a):
      """Return the text used for a positional argument in a formatted op call.

      sympy expressions are printed with ``sympy_str``; everything else with
      ``str``.
      """
      return sympy_str(a) if isinstance(a, sympy.Expr) else str(a)
  class MockHandler:
      """Ops handler that renders each op call as a string.

      Looking up an unknown (non-dunder) attribute returns a function that
      formats its call as ``name(arg0, ..., key=value)``. ``masked`` and
      ``indirect_indexing`` are defined explicitly, and ``_init_cls`` installs
      one static method per entry of ``magic_methods``/``inplace_methods`` that
      fills that entry's format string with the call's arguments.
      """

      def __getattr__(self, name):
          if name == "name":
              return "MockHandler"
          if name.startswith("__") and name.endswith("__"):
              # Dunder lookups come from Python protocols (copy, pickle, ...),
              # not from op calls. Without this, e.g. copy.deepcopy(handler)
              # "finds" __deepcopy__ and returns a string instead of a handler.
              raise AttributeError(name)

          def inner(*args, **kwargs):
              # Positional args go through _arg_str; keyword values use str().
              fargs = [_arg_str(a) for a in args]
              fargs.extend(f"{k}={v}" for k, v in kwargs.items())
              return f"{name}({', '.join(fargs)})"

          return inner

      @staticmethod
      def masked(mask, body, other):
          # ``body`` is a zero-argument callable; it is invoked so the masked
          # computation is rendered inline.
          return f"masked({mask}, {body()}, {other})"

      @staticmethod
      def indirect_indexing(index_var):
          # Parenthesize the rendered index and pass it to sympy_symbol
          # (from .utils; presumably builds a sympy Symbol — confirm).
          return sympy_symbol(f"({str(index_var)})")

      @classmethod
      def _init_cls(cls):
          """Install a static formatting method for every magic/inplace op."""

          def make_handler(format_string):
              # A factory (rather than a closure in the loop) so that each
              # handler captures its own format_string.
              @staticmethod
              def inner(*args):
                  return format_string.format(*args)

              return inner

          for name, format_string in chain(
              magic_methods.items(), inplace_methods.items()
          ):
              setattr(cls, name, make_handler(format_string))
  class KernelFormatterHandler:
      """Ops handler that records op calls as a straight-line listing.

      Every op call is forwarded to ``parent_handler``. The parent's result is
      written to ``output`` as ``tmpN = <result>`` and the fresh name ``tmpN``
      is returned, so nested op calls refer to earlier temporaries.
      Results of ``indirect_indexing`` are returned unchanged and not recorded.
      """

      def __init__(self, parent_handler):
          self.parent_handler = parent_handler
          self.output = IndentedBuffer()
          # Source of the numeric suffixes for tmp0, tmp1, ...
          self.var_counter = itertools.count()

      def __getattr__(self, name):
          if name.startswith("__") and name.endswith("__"):
              # Protocol lookups (copy, pickle, ...) are not ops; don't
              # fabricate an op wrapper for them.
              raise AttributeError(name)

          def inner(*args, **kwargs):
              line = getattr(self.parent_handler, name)(*args, **kwargs)
              if name == "indirect_indexing":
                  # Index values are passed through rather than bound to a tmp.
                  return line
              # replace line with a new variable name
              varname = f"tmp{next(self.var_counter)}"
              self.output.writeline(f"{varname} = {line}")
              return varname

          return inner

      def getvalue(self, result):
          """Append ``return <result>`` to the listing and return its text."""
          self.output.writeline(f"return {result}")
          return self.output.getvalue()
  class WrapperHandler:
      """Handler that forwards every attribute lookup to a wrapped handler.

      Attributes defined on a subclass take precedence, because
      ``__getattr__`` only runs for names not found by normal lookup.
      """

      def __init__(self, inner):
          self._inner = inner

      def __getattr__(self, item):
          if item == "_inner":
              # Only reached when __init__ never ran (e.g. an instance built via
              # __new__ by copy/pickle); delegating would recurse forever.
              raise AttributeError(item)
          return getattr(self._inner, item)
  # Attach the operator-style formatting methods (magic/inplace ops) to
  # MockHandler before it is used as the default ops handler.
  MockHandler._init_cls()

  # The ops handler falls back to MockHandler; every other virtualized global
  # falls back to NullHandler when nothing is installed.
  ops = Virtualized("ops", MockHandler)
  (_graph, _fake_mode, _kernel, _debug, _interpreter) = (
      Virtualized(vname, NullHandler)
      for vname in ("graph", "fake_mode", "kernel", "debug", "interpreter")
  )
  class _V:
      """Namespace through which codegen reads and swaps the virtualized globals.

      Each ``set_*`` attribute is the ``_set_handler`` of one virtualized
      global: it installs a value for the current thread and returns a context
      manager that restores the previous value on exit. Each property returns
      the value installed for the current thread, or that global's
      ``default()`` when nothing is installed.
      """

      # Handler classes re-exported so callers can reach them through ``V``.
      MockHandler = MockHandler
      KernelFormatterHandler = KernelFormatterHandler
      WrapperHandler = WrapperHandler

      # These must stay above the properties of the same base names: at this
      # point ``ops`` etc. still resolve to the module-level Virtualized objects.
      set_ops_handler = ops._set_handler
      get_ops_handler = ops._get_handler
      set_graph_handler = _graph._set_handler
      set_fake_mode = _fake_mode._set_handler
      set_kernel_handler = _kernel._set_handler
      set_debug_handler = _debug._set_handler
      set_interpreter_handler = _interpreter._set_handler

      @property
      def ops(self) -> MockHandler:
          """The operator handler specific to the current codegen task"""
          return ops._get_handler()

      @property
      def graph(self):
          """The graph currently being generated"""
          return _graph._get_handler()

      @property
      def fake_mode(self):
          """The fake mode installed via ``set_fake_mode`` for the current thread"""
          return _fake_mode._get_handler()

      @property
      def kernel(self):
          """The kernel currently being generated"""
          return _kernel._get_handler()

      @property
      def debug(self):
          """The debug handler installed via ``set_debug_handler`` for this thread"""
          return _debug._get_handler()

      @property
      def interpreter(self):
          """The interpreter installed via ``set_interpreter_handler`` for this thread"""
          return _interpreter._get_handler()


  # Module-wide accessor: read the virtualized globals as V.ops, V.graph, ...
  V = _V()