# codegen.py

import collections
import dataclasses
import re
import types
from typing import List

import torch.nn

from .bytecode_transformation import (
    create_dup_top,
    create_instruction,
    create_rot_n,
    Instruction,
)
from .exc import unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, istype, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    SymNodeVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
)


@dataclasses.dataclass
class GraphOutputEntry:
    """A single output of the generated graph and the VariableTracker backing it."""

    index: int
    variable: VariableTracker

    def merge(self, other: VariableTracker):
        # merge in any extra guards
        self.variable = self.variable.add_options(other)


class PyCodegen:
    """
    Helper class used for constructing Python bytecode
    """

    def __init__(
        self,
        tx=None,
        root: torch.nn.Module = None,
        graph_output_var: str = None,
        tempvars=None,
    ):
        self.root = root
        self.top_of_stack = None
        self.uses = collections.Counter()
        self.graph_outputs = collections.OrderedDict()
        self._output: List[Instruction] = []
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def __call__(self, value, allow_cache=True):
        """Generate code such that top-of-stack (TOS) is set to value"""
        if isinstance(value, Source):
            self._output.extend(value.reconstruct(self))
            self.clear_tos()
            return

        self.tx.output.guards.update(value.guards)

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if self.top_of_stack is value:
            output.append(create_dup_top())
            return

        if allow_cache:
            if value.mutable_local and value.mutable_local in self.tempvars:
                output.append(self.create_load(self.tempvars[value.mutable_local]))
                self.top_of_stack = value
                return
            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.source is not None and allow_cache:
            output.extend(value.source.reconstruct(self))
        elif value.is_python_constant() and is_safe_constant(
            value.as_python_constant()
        ):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                TensorWithTFOverrideVariable,
                UnspecializedPythonVariable,
            ),
        ):
            if isinstance(value, TensorWithTFOverrideVariable):
                # unwrap back to tensor
                value = value.tensor_variable

            graph_outputs_key = id(value.proxy)
            if graph_outputs_key not in graph_outputs:
                graph_outputs[graph_outputs_key] = GraphOutputEntry(
                    len(graph_outputs), value
                )
            else:
                graph_outputs[graph_outputs_key].merge(value)

            output.append(self.create_load(self.graph_output_var))
            output.append(
                self._create_load_const(graph_outputs[graph_outputs_key].index)
            )
            output.append(create_instruction("BINARY_SUBSCR"))

            if isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
                output.extend(
                    [
                        self.create_load_attr("item"),
                        create_instruction("CALL_FUNCTION", 0),
                    ]
                )
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_output(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                output.extend(value.reconstruct(self))
            except NotImplementedError:
                unimplemented(f"reconstruct: {value}")
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value
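
    # Illustrative note (not in the original source): for a TensorVariable whose
    # GraphOutputEntry index is 0, the tensor branch above emits roughly
    #     LOAD_FAST  <graph_output_var>   (via create_load)
    #     LOAD_CONST 0
    #     BINARY_SUBSCR
    # i.e. the value is reconstructed by indexing into the sequence of graph
    # outputs that the generated code leaves bound to graph_output_var.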

    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        if value.mutable_local:
            self.tempvars[value.mutable_local] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def setup_globally_cached(self, name, value):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert id(f_globals[name]) == id(value)
        else:
            f_globals[name] = value
        return [self.create_load_global(name, add=True)]
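
    # Illustrative note (not in the original source): a name such as
    # "my.helper-fn" is sanitized to "my_helper_fn", the value is stored in
    # f_globals under that key, and the single LOAD_GLOBAL instruction needed
    # to push it onto the stack is returned.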

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self):
        return self._output

    def create_load(self, name):
        if name in self.cell_and_freevars():
            return create_instruction(
                "LOAD_DEREF", self.cell_and_freevars().index(name), name
            )
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction(
            "LOAD_FAST", self.code_options["co_varnames"].index(name), name
        )

    def create_load_closure(self, name):
        assert name in self.cell_and_freevars()
        return create_instruction(
            "LOAD_CLOSURE", self.cell_and_freevars().index(name), name
        )

    def create_store(self, name):
        if name in self.cell_and_freevars():
            return create_instruction(
                "STORE_DEREF", self.cell_and_freevars().index(name), name
            )
        assert name in self.code_options["co_varnames"]
        return create_instruction(
            "STORE_FAST", self.code_options["co_varnames"].index(name), name
        )

    def create_load_global(self, name, add=False):
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_instruction(
            "LOAD_GLOBAL", self.code_options["co_names"].index(name), name
        )

    def create_load_const(self, value):
        assert is_safe_constant(value), f"unsafe constant {value}"
        return self._create_load_const(value)

    @staticmethod
    def get_const_index(code_options, value):
        co_consts = code_options["co_consts"]
        assert istype(co_consts, tuple)
        index = None
        for i, v in enumerate(co_consts):
            if type(v) is type(value) and v == value:
                index = i
                break
        if index is None:
            index = len(co_consts)
            co_consts = co_consts + (value,)
            code_options["co_consts"] = co_consts
        return index
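
    # Illustrative note (not in the original source): the lookup above matches
    # on both type and value, so a constant like True does not collapse into an
    # existing 1 in co_consts; when no match is found the value is appended and
    # code_options["co_consts"] is rebound to the extended tuple, so later
    # calls see the new constant.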

    def _create_load_const(self, value):
        index = self.get_const_index(self.code_options, value)
        return create_instruction("LOAD_CONST", index, value)

    create_load_output = _create_load_const

    def create_load_attr(self, name):
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] = self.code_options["co_names"] + (name,)
        return create_instruction(
            "LOAD_ATTR", self.code_options["co_names"].index(name), name
        )

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]

    def load_function_name(self, fn_name, num_on_stack=0):
        """Load the global fn_name, rotating it below the top num_on_stack stack entries"""
        return [self.create_load_global(fn_name, add=True)] + self.rot_n(
            num_on_stack + 1
        )

    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # desired rotate bytecode doesn't exist, generate equivalent bytecode
            return (
                [
                    create_instruction("BUILD_TUPLE", n),
                    self._create_load_const(rot_n_helper(n)),
                ]
                + create_rot_n(2)
                + [
                    create_instruction("CALL_FUNCTION_EX", 0),
                    create_instruction("UNPACK_SEQUENCE", n),
                ]
            )
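
    # Illustrative note (behavior inferred from the code above, not from the
    # original comments): when no single rotate-by-n instruction exists for the
    # running CPython version, the fallback packs the top n stack entries into
    # a tuple, loads the rot_n_helper(n) callable, swaps the two so the callable
    # sits below its argument tuple, invokes it via CALL_FUNCTION_EX, and
    # unpacks the returned sequence back onto the stack -- the net effect is
    # the same rotation.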

    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output
        for var in freevars:
            assert var in self.cell_and_freevars()
            output.append(
                create_instruction(
                    "LOAD_CLOSURE", self.cell_and_freevars().index(var), var
                )
            )
        output.append(create_instruction("BUILD_TUPLE", len(freevars)))
        output.append(self.create_load_const(code))
        output.append(self.create_load_const(fn_name))
        output.append(create_instruction("MAKE_FUNCTION", 0x08))
        output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()
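
    # Illustrative note (not in the original source): MAKE_FUNCTION flag 0x08
    # tells CPython that a tuple of closure cells sits on the stack below the
    # code object and function name, which is exactly what the
    # LOAD_CLOSURE/BUILD_TUPLE sequence above pushes; the trailing rot_n then
    # moves the new function below any num_on_stack values already on the stack.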

    def create_load_python_module(self, mod):
        """
        Generate a LOAD_GLOBAL instruction to fetch a given python module.
        """
        root_globals = self.tx.output.root_globals
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if root_globals.get(name, None) is mod:
            return self.create_load_global(name, add=True)
        mangled_name = f"___module_{name}_{id(mod)}"
        if mangled_name not in root_globals:
            self.tx.output.install_global(mangled_name, mod)
        return self.create_load_global(mangled_name, add=True)
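
    # Illustrative note (not in the original source): for a module such as
    # torch.nn.functional the package prefix is stripped, so the lookup key is
    # "functional"; if that name is not already bound to the module in the root
    # globals, a mangled key of the form "___module_functional_<id>" is
    # installed and loaded instead.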

    def make_call_generated_code(self, fn_name: str) -> List[Instruction]:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name))

        graphargs = self.tx.output.graphargs
        for arg in graphargs:
            if arg.is_unspecialized:
                self.extend_output(
                    [
                        self.create_load_python_module(torch),
                        self.create_load_attr("tensor"),
                    ]
                )
                self.extend_output(arg.load(self))
                self.extend_output(
                    [
                        create_instruction("CALL_FUNCTION", 1),
                    ]
                )
            else:
                self.extend_output(arg.load(self))

        self.append_output(create_instruction("CALL_FUNCTION", len(graphargs)))
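
    # Illustrative note (not in the original source): the emitted call is roughly
    #     fn_name(arg0, arg1, ..., torch.tensor(unspec_arg), ...)
    # with each unspecialized graph arg (typically a Python scalar) re-wrapped
    # as a tensor before being passed to the generated function.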

    def load_import_from(self, module_name, object_name):
        self.extend_output(
            AttrSource(self.tx.import_source(module_name), object_name).reconstruct(
                self
            )
        )

    def create_begin_finally(self):
        return create_instruction("BEGIN_FINALLY")
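

# A minimal usage sketch (hypothetical; real callers live in the output-graph
# and resume-execution machinery, and `tx`, `nn_root`, and `stack_values` here
# stand in for objects those callers provide):
#
#     cg = PyCodegen(tx, root=nn_root, graph_output_var="graph_out_0")
#     cg.foreach(stack_values)                  # reconstruct each VariableTracker
#     cg.append_output(cg.create_store("result"))
#     instructions = cg.get_instructions()      # bytecode to splice into the frame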