functions.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. import abc
  2. import enum
  3. import functools
  4. import inspect
  5. import itertools
  6. import types
  7. from typing import Dict, List
  8. import torch
  9. from .. import variables
  10. from ..bytecode_transformation import create_instruction
  11. from ..exc import unimplemented
  12. from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
  13. from ..utils import istensor, istype, make_cell
  14. from .base import typestr, VariableTracker
  15. def wrap_bound_arg(tx, val, options, source=None):
  16. # Source propagation is best effort since not every object we encounter has a source to begin with.
  17. assert (
  18. "source" not in options
  19. ), "Source needs to be separate from options due to recursive calls for lists/dicts"
  20. if isinstance(val, dict):
  21. return variables.ConstDictVariable(
  22. {
  23. k: wrap_bound_arg(tx, v, options, source=getattr(v, "source", None))
  24. for k, v in val.items()
  25. },
  26. dict,
  27. **options,
  28. )
  29. elif isinstance(val, (tuple, list)):
  30. cls = variables.BaseListVariable.cls_for(type(val))
  31. return cls(
  32. [
  33. wrap_bound_arg(tx, x, options, source=getattr(x, "source", None))
  34. for x in val
  35. ],
  36. **options,
  37. )
  38. if variables.ConstantVariable.is_literal(val) or istype(
  39. val, (torch.Size, torch.device, torch.dtype)
  40. ):
  41. return variables.ConstantVariable(val, **options)
  42. elif isinstance(val, types.FunctionType):
  43. return variables.UserFunctionVariable(val, source=source, **options)
  44. elif isinstance(val, enum.Enum):
  45. return variables.EnumVariable(val, source=source, **options)
  46. elif isinstance(val, (type, abc.ABCMeta)):
  47. return variables.UserDefinedClassVariable(val, source=source, **options)
  48. elif istensor(val):
  49. from torch._dynamo.variables.builder import VariableBuilder
  50. return VariableBuilder(tx, source=source, **options)(val)
  51. else:
  52. assert isinstance(val, VariableTracker), typestr(val)
  53. return val
  54. def wrap_args_kwargs(tx, result, options):
  55. for k, v in list(result.items()):
  56. if isinstance(v, (tuple, dict)):
  57. # args/kwargs
  58. result[k] = wrap_bound_arg(tx, v, options)
  59. def init_cellvars(parent, result, code):
  60. closure_cells = dict()
  61. side_effects = parent.output.side_effects
  62. for name in code.co_cellvars:
  63. closure_cells[name] = side_effects.track_cell_new()
  64. if name in result:
  65. side_effects.store_cell(closure_cells[name], result.pop(name))
  66. return closure_cells
  67. class BaseUserFunctionVariable(VariableTracker):
  68. def get_filename(self):
  69. return self.get_code().co_filename
  70. def get_name(self):
  71. return self.get_code().co_name
  72. def call_function(
  73. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  74. ) -> "VariableTracker":
  75. return tx.inline_user_function_return(
  76. self, list(self.self_args()) + list(args), kwargs
  77. )
  78. def num_parameters(self):
  79. return len(inspect.signature(self.get_function()).parameters)
  80. def closure_vars(self, tx):
  81. return {}
  82. class UserFunctionVariable(BaseUserFunctionVariable):
  83. """Some unsupported user-defined global function"""
  84. def __init__(self, fn, is_constant=False, **kwargs):
  85. super().__init__(**kwargs)
  86. if getattr(fn, "_dynamo_marked_constant", False):
  87. # This method should be treated as a constant for the purposes of compilation
  88. self.is_constant = True
  89. else:
  90. self.is_constant = False
  91. assert isinstance(
  92. fn, (types.FunctionType, torch.jit.ScriptFunction)
  93. ), f"expected FunctionType found {typestr(fn)} {fn}"
  94. # unpack @torch._dynamo.optimize()(fn) wrapped function
  95. fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
  96. # unpack torch.jit.script_if_tracing
  97. if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
  98. fn = inspect.getattr_static(fn, "__original_fn", fn)
  99. self.fn: types.FunctionType = fn
  100. def self_args(self):
  101. return []
  102. def get_function(self):
  103. return self.fn
  104. def get_code(self):
  105. return self.fn.__code__
  106. def python_type(self):
  107. return types.FunctionType
  108. def has_self(self):
  109. return getattr(self.fn, "__self__", None) is not None
  110. def get_globals(self):
  111. return self.fn.__globals__
  112. def bind_args(self, parent, args, kwargs):
  113. assert not self.is_constant
  114. options = VariableTracker.propagate([self])
  115. tx = parent.output.root_tx
  116. wrap = functools.partial(wrap_bound_arg, tx=tx, options=options)
  117. fn: types.FunctionType = self.fn
  118. defaults = fn.__defaults__ or []
  119. defaults_sources = [
  120. None if self.source is None else DefaultsSource(self.source, idx)
  121. for idx, _ in enumerate(defaults)
  122. ]
  123. fake_func = types.FunctionType(
  124. fn.__code__,
  125. fn.__globals__,
  126. fn.__name__,
  127. tuple(
  128. [
  129. wrap(val=arg, source=source)
  130. for arg, source in zip(defaults, defaults_sources)
  131. ]
  132. ),
  133. fn.__closure__,
  134. )
  135. if fn.__kwdefaults__:
  136. kwdefaults_sources = {
  137. k: None
  138. if self.source is None
  139. else DefaultsSource(self.source, k, is_kw=True)
  140. for k in fn.__kwdefaults__
  141. }
  142. fake_func.__kwdefaults__ = {
  143. k: wrap(val=v, source=kwdefaults_sources[k])
  144. for k, v in fn.__kwdefaults__.items()
  145. }
  146. bound = inspect.signature(fake_func).bind(*args, **kwargs)
  147. bound.apply_defaults()
  148. result = dict(bound.arguments.items())
  149. wrap_args_kwargs(tx, result, options)
  150. closure_cells = init_cellvars(parent, result, fn.__code__)
  151. closure = self.fn.__closure__ or ()
  152. assert len(closure) == len(self.fn.__code__.co_freevars)
  153. for idx, name, cell in zip(
  154. itertools.count(), self.fn.__code__.co_freevars, closure
  155. ):
  156. if name == "__class__":
  157. source = AttrSource(self.source, "__class__") if self.source else None
  158. result[name] = variables.UserDefinedClassVariable(
  159. cell.cell_contents,
  160. source=source,
  161. )
  162. else:
  163. var = tx.match_nested_cell(name, cell)
  164. if var is not None:
  165. # optimization for cleaner codegen
  166. result[name] = var
  167. elif self.source:
  168. from .builder import VariableBuilder
  169. side_effects = parent.output.side_effects
  170. if cell in side_effects:
  171. out = side_effects[cell]
  172. else:
  173. closure_cell = GetItemSource(
  174. AttrSource(self.source, "__closure__"), idx
  175. )
  176. closure_cell_contents = AttrSource(
  177. closure_cell, "cell_contents"
  178. )
  179. contents_var = VariableBuilder(parent, closure_cell_contents)(
  180. cell.cell_contents
  181. )
  182. if (
  183. closure_cell_contents.name()
  184. not in tx.mutated_closure_cell_contents
  185. ):
  186. # Optimistically don't allocate the cell, to
  187. # reduce the number of side effects. This is
  188. # important for cond, as without it, any accesses
  189. # to closures create side effects and cond doesn't
  190. # support side effects. If we're wrong and this
  191. # closure cell gets written to, we will restart
  192. # the analysis with this cell's name in the
  193. # mutated list here
  194. result[name] = contents_var
  195. continue
  196. # cells are written to with "cell_contents",
  197. # so the source should just be the closure_cell, not its contents
  198. out = side_effects.track_cell_existing(closure_cell, cell)
  199. side_effects.store_cell(
  200. out,
  201. contents_var,
  202. )
  203. result[name] = out
  204. else:
  205. unimplemented("inline with __closure__")
  206. return result, closure_cells
  207. def export_freevars(self, parent, child):
  208. pass
  209. def call_function(
  210. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  211. ) -> "VariableTracker":
  212. if self.is_constant:
  213. options = VariableTracker.propagate(self, args, kwargs.values())
  214. return invoke_and_store_as_constant(
  215. tx, self.fn, self.get_name(), options, args, kwargs
  216. )
  217. return super().call_function(tx, args, kwargs)
  218. class UserMethodVariable(UserFunctionVariable):
  219. """Some unsupported user-defined method"""
  220. def __init__(self, fn, obj, **kwargs):
  221. super().__init__(fn=fn, **kwargs)
  222. self.obj = obj
  223. def __str__(self):
  224. return f"{self.__class__.__name__}({self.fn}, {self.obj})"
  225. def self_args(self):
  226. return [self.obj]
  227. def python_type(self):
  228. return types.MethodType
  229. def call_function(
  230. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  231. ) -> "VariableTracker":
  232. if isinstance(self.obj, variables.NNModuleVariable):
  233. module_attr = getattr(self.fn, "__module__", "")
  234. if (
  235. module_attr is not None
  236. and module_attr.startswith("torch.nn.")
  237. or self.is_constant
  238. ):
  239. return self.obj.call_method(
  240. tx, self.fn.__name__, args, kwargs, constant=self.is_constant
  241. ).add_options(self)
  242. return super().call_function(tx, args, kwargs)
  243. def num_parameters(self):
  244. return super().num_parameters() - 1
  245. class WrappedUserMethodVariable(UserMethodVariable):
  246. def __init__(self, wrapped, context, **kwargs):
  247. kwargs.pop("fn", None)
  248. kwargs.pop("obj", None)
  249. super().__init__(wrapped.fn, wrapped.obj, **kwargs)
  250. self.wrapped = wrapped
  251. self.context = context
  252. def call_function(
  253. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  254. ) -> "VariableTracker":
  255. self.context.enter(tx)
  256. result = super().call_function(tx, args, kwargs)
  257. self.context.exit(tx)
  258. return result
  259. class WrappedUserFunctionVariable(UserFunctionVariable):
  260. def __init__(self, wrapped, context, **kwargs):
  261. kwargs.pop("fn", None)
  262. kwargs.pop("obj", None)
  263. super().__init__(wrapped.fn, **kwargs)
  264. self.wrapped = wrapped
  265. self.context = context
  266. def call_function(
  267. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  268. ) -> "VariableTracker":
  269. self.context.enter(tx)
  270. result = super().call_function(tx, args, kwargs)
  271. self.context.exit(tx)
  272. return result
  273. def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
  274. def convert(x):
  275. if isinstance(x, variables.TensorVariable):
  276. return x.get_real_value()
  277. return x.as_python_constant()
  278. args = [convert(x) for x in args]
  279. kwargs = {k: convert(v) for k, v in kwargs.items()}
  280. res = fn(*args, **kwargs)
  281. return tx.output.register_attr_or_module(
  282. res,
  283. name,
  284. source=ConstantSource(name),
  285. **options,
  286. )
  287. class NestedUserFunctionVariable(BaseUserFunctionVariable):
  288. def __init__(
  289. self,
  290. fn_name,
  291. code,
  292. f_globals,
  293. defaults,
  294. kwdefaults,
  295. annotations,
  296. closure,
  297. closure_scope,
  298. **kwargs,
  299. ):
  300. super().__init__(**kwargs)
  301. assert isinstance(fn_name.as_python_constant(), str)
  302. assert isinstance(code.as_python_constant(), types.CodeType)
  303. assert isinstance(f_globals, dict)
  304. self.fn_name = fn_name
  305. self.code = code
  306. self.f_globals = f_globals
  307. self.defaults = defaults
  308. self.kwdefaults = kwdefaults
  309. self.annotations = annotations
  310. self.closure = closure
  311. if closure is None:
  312. closure_scope = None
  313. self.closure_scope = closure_scope
  314. def self_args(self):
  315. return []
  316. def get_code(self):
  317. return self.code.as_python_constant()
  318. def get_function(self):
  319. if self.closure:
  320. raise NotImplementedError()
  321. func = types.FunctionType(
  322. self.code.as_python_constant(),
  323. self.f_globals,
  324. self.fn_name.as_python_constant(),
  325. )
  326. if self.defaults:
  327. func.__defaults__ = self.defaults.as_python_constant()
  328. if self.kwdefaults:
  329. func.__kwdefaults__ = self.kwdefaults.as_python_constant()
  330. if self.annotations:
  331. annotations = self.annotations.as_python_constant()
  332. if isinstance(annotations, tuple):
  333. from itertools import pairwise
  334. annotations = dict(pairwise(annotations))
  335. # TypeError: __annotations__ must be set to a dict object
  336. assert isinstance(annotations, dict)
  337. func.__annotations__ = annotations
  338. return func
  339. def has_closure(self):
  340. return self.closure is not None
  341. def has_self(self):
  342. return False
  343. def get_globals(self):
  344. return self.f_globals
  345. def bind_args(self, parent, args, kwargs):
  346. code = self.get_code()
  347. func = types.FunctionType(
  348. code,
  349. self.f_globals,
  350. self.fn_name.as_python_constant(),
  351. tuple(self.defaults.items) if self.defaults else None,
  352. tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
  353. )
  354. if self.kwdefaults:
  355. func.__kwdefaults__ = self.kwdefaults.items
  356. bound = inspect.signature(func).bind(*args, **kwargs)
  357. bound.apply_defaults()
  358. result = dict(bound.arguments.items())
  359. wrap_args_kwargs(parent.output.root_tx, result, VariableTracker.propagate(self))
  360. closure_cells = init_cellvars(parent, result, code)
  361. for idx, name in enumerate(code.co_freevars):
  362. assert getattr(self.closure.items[idx], name, name) == name
  363. assert name not in result
  364. closure_cells[name] = self.closure.items[idx]
  365. return result, closure_cells
  366. def export_freevars(self, parent, child):
  367. code = self.get_code()
  368. for var in code.co_freevars:
  369. if var in child.symbolic_locals:
  370. parent.symbolic_locals[var] = child.symbolic_locals[var]
  371. def reconstruct(self, codegen):
  372. flags = 0x00
  373. if self.defaults:
  374. flags |= 0x01
  375. codegen(self.defaults)
  376. if self.kwdefaults:
  377. flags |= 0x02
  378. codegen(self.kwdefaults)
  379. if isinstance(self.annotations, variables.ConstDictVariable) or isinstance(
  380. self.annotations, variables.TupleVariable
  381. ):
  382. flags |= 0x04
  383. try:
  384. if isinstance(self.annotations, variables.ConstDictVariable):
  385. annotations = {
  386. k: v.as_python_constant()
  387. for k, v in self.annotations.items.items()
  388. }
  389. else:
  390. annotations = tuple(
  391. [v.as_python_constant() for v in self.annotations.items]
  392. )
  393. codegen.extend_output([codegen._create_load_const(annotations)])
  394. except NotImplementedError:
  395. codegen(self.annotations)
  396. if self.closure:
  397. flags |= 0x08
  398. codegen(self.closure)
  399. codegen(self.code)
  400. codegen(self.fn_name)
  401. return [create_instruction("MAKE_FUNCTION", flags)]