# misc.py

import inspect
import sys
import types
from typing import Dict, List

import torch._C
from torch._guards import Guard, GuardSource

from .. import variables
from ..bytecode_transformation import create_instruction
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import identity, proxy_args_kwargs
from .base import VariableTracker
from .functions import (
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
    WrappedUserFunctionVariable,
    WrappedUserMethodVariable,
)


class SuperVariable(VariableTracker):
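    """Tracks the result of a `super()` call made in traced user code."""
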
    def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
        super().__init__(**kwargs)
        self.typevar = typevar
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true

    def reconstruct(self, codegen):
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            return [create_instruction("CALL_FUNCTION", 2)]
        else:
            return [create_instruction("CALL_FUNCTION", 1)]

    def const_getattr(self, tx, name):
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            return getattr(self.typevar.as_python_constant(), name)
        search_type = self.typevar.as_python_constant()

        # We default to the python type of the object. However, if this is
        # a `type` or subclass of `type`, then the original object represents
        # the user defined type.
        type_to_use = self.objvar.python_type()
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        options = VariableTracker.propagate(
            self, args, kwargs.values(), self.objvar, self.typevar
        )
        inner_fn = self.const_getattr(self, name)
        source = None if self.source is None else AttrSource(self.source, name)
        if inner_fn is object.__init__:
            return LambdaVariable(identity, **options)
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source, **options
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source, **options
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(f"non-function or method super: {inner_fn}")


class UnknownVariable(VariableTracker):
    """
    It could be anything!
    """


class ComptimeVariable(VariableTracker):
    """
    This variable is special: it lets you execute arbitrary code at
    Dynamo compile time
    """

    def reconstruct(self, codegen):
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from ..comptime import comptime

        # To support the comptime.print_graph convenience accessors
        from .functions import UserFunctionVariable

        return UserFunctionVariable(
            getattr(comptime, name), source=AttrSource(self.source, name)
        )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well
        assert not kwargs
        assert len(args) == 1
        fn = args[0]
        if isinstance(fn, UserFunctionVariable):
            fn.get_function()(ComptimeContext(tx))
        elif isinstance(fn, NestedUserFunctionVariable):
            # We have to manually bind the freevars ourselves
            code = fn.get_code()
            assert not fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {code.co_freevars}"
            )
            func = types.FunctionType(
                code,
                fn.f_globals,
                fn.fn_name.as_python_constant(),
                tuple(fn.defaults.items) if fn.defaults else None,
                # We could automatically promote free variables into
                # ComptimeVar but this is confusing if you access
                # a free variable that we actually DO have the runtime
                # value for
                # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
                tuple(),
            )
            func(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
        return variables.ConstantVariable(None)


class ClosureVariable(UnknownVariable):
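    """Tracks a closure (cell) variable by name; reconstructed by loading the closure cell."""
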
    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        return [codegen.create_load_closure(self.name)]


class NewCellVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class NewGlobalVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ContextWrappingVariable(VariableTracker):
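    """
    Base class for variables that model a Python context manager (grad mode,
    autocast, CUDA streams, ...). Subclasses supply _call_func/module_name/fn_name;
    reconstruct() emits the try/finally bytecode that restores the initial state.
    """
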
    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.recursively_contains = (
            set()
        )  # This var doesn't contain any child vars and doesn't support clone() properly,
        # so don't populate this automatically

    def enter(self, tx):
        self._call_func(tx, self.target_values)
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def exit(self, tx, *args):
        self._call_func(tx, self.initial_values)
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def reconstruct(self, codegen, target_inst=None):
        """
        Generate the following Python bytecode, with a `torch._C._set_grad_enabled` call

        Python 3.8
             0 LOAD_GLOBAL              0 (torch)
             2 LOAD_ATTR                1 (_C)
             4 LOAD_METHOD              2 (_set_grad_enabled)
             6 LOAD_CONST               1 (False)
             8 CALL_METHOD              1
            10 POP_TOP

            12 SETUP_FINALLY           10 (to 24)

            14 LOAD_GLOBAL              3 (user_inst)
            16 CALL_FUNCTION            0
            18 POP_TOP
            20 POP_BLOCK
            22 BEGIN_FINALLY

            24 LOAD_GLOBAL              0 (torch)
            26 LOAD_ATTR                1 (_C)
            28 LOAD_METHOD              2 (_set_grad_enabled)
            30 LOAD_CONST               2 (True)
            32 CALL_METHOD              1
            34 POP_TOP
            36 END_FINALLY
            38 LOAD_CONST               0 (None)
            40 RETURN_VALUE

        Instructions 0-10 and 24-34 call torch._C._set_grad_enabled(True/False)

        Python 3.9, 3.10
             0 LOAD_GLOBAL              0 (torch)
             2 LOAD_ATTR                1 (_C)
             4 LOAD_METHOD              2 (_set_grad_enabled)
             6 LOAD_CONST               1 (False)
             8 CALL_METHOD              1
            10 POP_TOP

            12 SETUP_FINALLY           22 (to 36)

            14 LOAD_GLOBAL              3 (user_inst)
            16 CALL_FUNCTION            0
            18 POP_TOP
            20 POP_BLOCK

            22 LOAD_GLOBAL              0 (torch)
            24 LOAD_ATTR                1 (_C)
            26 LOAD_METHOD              2 (_set_grad_enabled)
            28 LOAD_CONST               2 (True)
            30 CALL_METHOD              1
            32 POP_TOP
            34 JUMP_FORWARD            14 (to 50)

            36 LOAD_GLOBAL              0 (torch)
            38 LOAD_ATTR                1 (_C)
            40 LOAD_METHOD              2 (_set_grad_enabled)
            42 LOAD_CONST               2 (True)
            44 CALL_METHOD              1
            46 POP_TOP
            48 RERAISE
            50 LOAD_CONST               0 (None)
            52 RETURN_VALUE
        """
        if self.target_values == self.initial_values:
            return ([], [])

        def set_context_insts(values):
            attr_source = AttrSource(
                codegen.tx.import_source(self.module_name()), self.fn_name()
            )
            load_set_context_enabling_insts = attr_source.reconstruct(codegen)

            if values:
                loads = [codegen.create_load_const(val) for val in values]
            else:
                loads = []

            return [
                *load_set_context_enabling_insts,
                *loads,
                create_instruction("CALL_FUNCTION", len(loads)),
                create_instruction("POP_TOP"),
            ]

        init_block = set_context_insts(self.target_values)
        finally_block = set_context_insts(self.initial_values)
        setup_final_inst = create_instruction("SETUP_FINALLY", target=finally_block[0])
        prologue = init_block + [setup_final_inst]

        # Generate the epilogue - starts with 20 POP_BLOCK and ends at 34 POP_TOP
        if sys.version_info < (3, 9):
            # Python 3.8: close the block with POP_BLOCK/BEGIN_FINALLY, then run the finally block
            epilogue = [
                create_instruction("POP_BLOCK"),
                codegen.create_begin_finally(),
                *finally_block,
                create_instruction("END_FINALLY"),
            ]
        else:
            except_block = set_context_insts(self.initial_values)
            epilogue = [
                create_instruction("POP_BLOCK"),
                *except_block,
                create_instruction("JUMP_FORWARD", target=target_inst),
                *finally_block,
                create_instruction("RERAISE"),
            ]

        return (prologue, epilogue)

    def _call_func(self, tx, initial_values):
        raise NotImplementedError("_call_func called on base")

    def module_name(self):
        raise NotImplementedError("module_name called on base")

    def fn_name(self):
        raise NotImplementedError("fn_name called on base")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert len(args) == 1
        if isinstance(args[0], NestedUserFunctionVariable):
            args[0] = UserFunctionVariable(args[0].get_function())
        assert isinstance(args[0], UserMethodVariable) or isinstance(
            args[0], UserFunctionVariable
        )
        if isinstance(args[0], UserMethodVariable):
            return WrappedUserMethodVariable(args[0], self)
        if isinstance(args[0], UserFunctionVariable):
            return WrappedUserFunctionVariable(args[0], self)


class GradModeVariable(ContextWrappingVariable):
    """represents torch.{no_grad,enable_grad,set_grad_enabled}()"""

    _guards_singleton = {Guard("", GuardSource.GLOBAL, GuardBuilder.GRAD_MODE)}

    @staticmethod
    def create(tx, target_value, **kwargs):
        var = GradModeVariable(
            target_values=[target_value],
            initial_values=[torch.is_grad_enabled()],
            **kwargs,
        )
        var._call_func(tx, [target_value])
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.guards = self.guards | self._guards_singleton

    def enter(self, tx):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def _call_func(self, tx, values):
        assert len(values) == 1
        value = values[0]
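        # Record the mode switch as a node in the FX graph and also flip it
        # eagerly so the rest of tracing runs under the new grad mode.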
        tx.output.create_node(
            "call_function", torch._C._set_grad_enabled, (value,), {}
        )
        torch._C._set_grad_enabled(value)

    def module_name(self):
        return "torch"

    def fn_name(self):
        return "set_grad_enabled"


class AutocastModeVariable(ContextWrappingVariable):
    @staticmethod
    def create(target_values, kwargs):
        # device_type : str,
        # dtype : Optional[_dtype] = None,
        # enabled : bool = True,
        # cache_enabled : Optional[bool] = None
        bound_args = inspect.signature(torch.autocast).bind(*target_values, **kwargs)
        bound_args.apply_defaults()
        target_values = []
        kwargs.clear()

        for key in ["device_type", "dtype", "enabled", "cache_enabled"]:
            arg = bound_args.arguments[key]
            if isinstance(arg, VariableTracker):
                target_values.append(bound_args.arguments[key].as_python_constant())
            else:
                target_values.append(bound_args.arguments[key])

        var = AutocastModeVariable(target_values, initial_values=None, **kwargs)
        return var

    def __init__(self, target_values, initial_values=None, **kwargs):
        mode = kwargs.pop("mode", None)
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )
        self.target_values = target_values
        self.mode = mode

    def exit(self, tx, *args):
        self.mode = tx.output.create_node(
            "call_function", exit_functional_autocast, (self.mode,), {}
        )

    def enter(self, tx):
        self.mode = tx.output.create_node(
            "call_function", enter_functional_autocast, (*self.target_values,), {}
        )

    def module_name(self):
        return "torch.amp.autocast_mode"

    def fn_name(self):
        return "autocast"
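

# Module-level helpers recorded into the FX graph by AutocastModeVariable.enter/exit;
# at runtime they construct the autocast context manager and call __enter__/__exit__.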
def enter_functional_autocast(*vals):
    mode = torch.amp.autocast(*vals)
    mode.__enter__()
    return mode


def exit_functional_autocast(mode):
    mode.__exit__(None, None, None)


class NullContextVariable(ContextWrappingVariable):
    """
    This class represents Python contextlib.nullcontext.
    It's used as a placeholder for other context managers that Dynamo doesn't
    support yet, e.g., torch.autograd.profiler.record_function.
    """

    def __init__(self, target_values=None, **kwargs):
        super().__init__(target_values=target_values, **kwargs)

    def enter(self, tx):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def exit(self, tx, *args):
        return variables.ConstantVariable(None, **VariableTracker.propagate(self))

    def module_name(self):
        return "contextlib"

    def fn_name(self):
        return "nullcontext"


class CUDAStreamContextVariable(ContextWrappingVariable):
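    """Models a torch.cuda.stream(...) context: enter/exit switch the current
    stream both in the traced graph and eagerly via torch.cuda.set_stream."""
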
    @staticmethod
    def create(tx, target_value, **kwargs):
        from .builder import wrap_fx_proxy_cls

        current_stream = wrap_fx_proxy_cls(
            CUDAStreamVariable,
            tx,
            tx.output.create_proxy(
                "call_function",
                torch.cuda.current_stream,
                (None,),
                {},
            ),
        )
        return CUDAStreamContextVariable(
            target_values=[target_value],
            initial_values=[current_stream],
            **kwargs,
        )

    def __init__(self, target_values, initial_values=None, **kwargs):
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx):
        tx.output.create_proxy(
            "call_function",
            torch.cuda.set_stream,
            (self.target_values[0].as_proxy(),),
            {},
        )
        torch.cuda.set_stream(self.target_values[0].value)

    def exit(self, tx, *args):
        tx.output.create_proxy(
            "call_function",
            torch.cuda.set_stream,
            (self.initial_values[0].as_proxy(),),
            {},
        )
        torch.cuda.set_stream(self.initial_values[0].value)

    def fn_name(self):
        return "cuda.stream"


class CUDAStreamVariable(VariableTracker):
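    """Wraps a CUDA stream value together with its FX proxy; any method call on
    it is unsupported and triggers a graph break."""
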
    def __init__(self, proxy, value, **kwargs):
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("cuda stream")

    def as_proxy(self):
        return self.proxy


class WithExitFunctionVariable(VariableTracker):
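    """Represents the bound __exit__ of a context manager being traced; calling
    it delegates to the wrapped ContextWrappingVariable.exit()."""
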
    def __init__(self, ctx: ContextWrappingVariable, target, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(ctx, ContextWrappingVariable)
        self.ctx = ctx
        self.target = target

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        assert not kwargs
        return self.ctx.exit(tx, *args)

    def reconstruct(self, codegen):
        # Note here we reconstruct the context manager rather than the
        # exit function. The handler generated by BlockStackEntry
        # will re-enter the context in the resume function.
        output = AttrSource(
            codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
        ).reconstruct(codegen)

        if codegen.tx.output.partial_convert:
            loads = [codegen.create_load_const(val) for val in self.ctx.target_values]
            output.extend(loads)
            output.extend(
                [
                    create_instruction("CALL_FUNCTION", len(loads)),
                    create_instruction("SETUP_WITH", target=self.target),
                    create_instruction("POP_TOP"),
                ]
            )
        return output


class InspectSignatureVariable(VariableTracker):
    """represents inspect.signature(...)"""

    @staticmethod
    def create(callable, **kwargs):
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        return InspectSignatureVariable(callable)

    def __init__(self, inspected, **kwargs):
        super().__init__(**kwargs)
        self.inspected = inspected


class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
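        # Trace fn_cls.forward directly, substituting a BlackHoleVariable for the
        # ctx argument so ctx attribute writes and save_for_backward are ignored.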
        requires_grad = False

        def visit(node):
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True
            return node

        VariableTracker.apply(visit, (args, kwargs))

        if requires_grad and torch.is_grad_enabled():
            # TODO(jansel): handle this in training mode
            unimplemented("autograd.Function with requires_grad")

        args = [BlackHoleVariable()] + list(args)
        options = VariableTracker.propagate(self, args, kwargs.values())
        options["source"] = AttrSource(AttrSource(self.source, "__class__"), "forward")
        fn = self.fn_cls.forward
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, **options).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__, variables.UserDefinedClassVariable(self.fn_cls), **options
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_function(self, tx, args, kwargs):
        options = VariableTracker.propagate(self, args, kwargs.values())
        return AutogradFunctionVariable(self.fn_cls, source=self.source, **options)


class BlackHoleVariable(VariableTracker):
    """An autograd.function context that just ignores everything (for forward extraction)"""

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert name in ("__setattr__", "save_for_backward"), name
        return variables.ConstantVariable(
            None, **VariableTracker.propagate(self, args, kwargs.values())
        )


class AutogradFunctionContextVariable(VariableTracker):
    """
    An autograd.function context used after a graph break in forward.
    Any method call on this context object will graph break.
    This is different from BlackHoleVariable, which is only used in inference mode.
    """

    pass


class LambdaVariable(VariableTracker):
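    """Wraps a Python callable; call_function applies it directly to the argument
    VariableTrackers and propagates this variable's options onto the result."""
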
    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return self.fn(*args, **kwargs).add_options(self)


class GetAttrVariable(VariableTracker):
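    """Represents the result of getattr(obj, name) on another tracked variable."""
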
    def __init__(self, obj, name, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError()
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError()
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError()
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        codegen(self.obj)
        return codegen.create_load_attrs(self.name)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        # This variable is True when it corresponds to user code such as
        #
        #   super().__torch_function__(...)
        #
        # and the super().__torch_function__ attribute resolves
        # to torch.Tensor.__torch_function__.
        is_original_tensor_torch_function = (
            self.name == "__torch_function__"
            and isinstance(self.obj, SuperVariable)
            # for now, only support one level of inheritance
            and len(self.obj.objvar.value.__mro__) > 1
            and self.obj.objvar.value.__mro__[1] == torch.Tensor
        )
        if is_original_tensor_torch_function:
            # Instead of tracing inside torch.Tensor.__torch_function__,
            # record the `call_function` or `call_method` call into the graph.
            from . import TorchVariable

            original_torch_or_getattr_variable = args[0]
            new_args = args[2].items
            new_kwargs = args[3].items
            options = VariableTracker.propagate(self, new_args, new_kwargs.values())
            # Disable __torch_function__ here to prevent the clone of the
            # example tensor from going into the override.
            with torch._C.DisableTorchFunctionSubclass():
                if isinstance(args[0], TorchVariable):
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_function",
                            original_torch_or_getattr_variable.value,
                            *proxy_args_kwargs(new_args, new_kwargs),
                        ),
                        **options,
                    )
                elif isinstance(args[0], GetAttrVariable):
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_method",
                            original_torch_or_getattr_variable.name,
                            *proxy_args_kwargs(new_args, new_kwargs),
                        ),
                        **options,
                    )
                else:
                    unimplemented(
                        f"GetAttrVariable.call_function original __torch_function__ {args}"
                    )

        if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
            return self.obj.call_apply(tx, args, kwargs).add_options(self)
        # calling a parent class's non-classmethod from a child class
        # https://github.com/pytorch/pytorch/issues/90558
        elif (
            isinstance(self.obj, variables.UserDefinedClassVariable)
            and len(args) > 0
            and issubclass(args[0].python_type(), self.obj.value)
        ):
            return SuperVariable(self.obj, args[0], True).call_method(
                tx, self.name, args[1:], kwargs
            )
        return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if (
            name == "__len__"
            and isinstance(self.obj, InspectSignatureVariable)
            and self.name == "parameters"
        ):
            return variables.ConstantVariable(
                self.obj.inspected.num_parameters(),
                **VariableTracker.propagate(self, self.obj, self.obj.inspected),
            )
        return super().call_method(tx, name, args, kwargs)


class PythonModuleVariable(VariableTracker):
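    """Tracks a reference to a Python module object."""
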
    def __init__(self, value: types.ModuleType, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return types.ModuleType


class SkipFilesVariable(VariableTracker):
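    """Wraps a callable coming from a file in Dynamo's skip list; calling it
    always graph breaks via unimplemented()."""
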
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        else:
            try:
                path = inspect.getfile(self.value)
            except TypeError:
                path = f"Builtin {self.value.__name__}"
            unimplemented(
                f"call_function {self.value.__qualname__} in skip_files {path}"
            )


class TypingVariable(VariableTracker):
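    """Wraps objects from the typing module; only constant __getitem__
    (e.g. List[int]) is supported."""
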
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__" and len(args) == 1:
            return variables.ConstantVariable(
                self.value[args[0].as_python_constant()],
                **VariableTracker.propagate(self, args),
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value


class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*` for better error messages.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        unimplemented("numpy")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value