nn_module.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. import functools
  2. import inspect
  3. import itertools
  4. import types
  5. from contextlib import contextmanager
  6. from typing import Dict, List
  7. import torch.nn
  8. from .. import skipfiles, variables
  9. from ..allowed_functions import is_allowed
  10. from ..exc import RestartAnalysis, unimplemented
  11. from ..guards import GuardBuilder
  12. from ..mutation_guard import GenerationTracker
  13. from ..source import AttrSource, GetItemSource, NNModuleSource, NotNNModuleSource
  14. from ..utils import (
  15. is_lazy_module,
  16. is_safe_constant,
  17. istensor,
  18. istype,
  19. proxy_args_kwargs,
  20. )
  21. from .base import MutableLocal, typestr, VariableTracker
  22. from .functions import invoke_and_store_as_constant
  23. from .lists import SliceVariable
  24. from .user_defined import UserDefinedObjectVariable
  25. class NNModuleVariable(VariableTracker):
  26. _nonvar_fields = ["module_type", "module_key"]
  27. def __init__(self, module_type: type, module_key: str, **kwargs):
  28. super().__init__(**kwargs)
  29. self.module_type = module_type
  30. self.module_key = module_key
  31. assert self.source
  32. def python_type(self):
  33. return self.module_type
  34. def _wrap_submodule(self, tx, source, submod, *key_extra, **options):
  35. return
  36. def unpack_var_sequence(self, tx):
  37. # implement list/iter/tuple/etc calls
  38. base = tx.output.get_submodule(self.module_key)
  39. options = VariableTracker.propagate([self])
  40. assert isinstance(
  41. base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
  42. ), typestr(base)
  43. assert self.source
  44. result = []
  45. for idx, submod in enumerate(base):
  46. result.append(
  47. tx.output.register_attr_or_module(
  48. submod,
  49. self.module_key,
  50. idx,
  51. source=NNModuleSource(GetItemSource(self.source, idx)),
  52. **options,
  53. )
  54. )
  55. return result
  56. def call_hasattr(self, tx, name: str) -> "VariableTracker":
  57. options = VariableTracker.propagate(self)
  58. mod = tx.output.get_submodule(self.module_key)
  59. result = hasattr(mod, name)
  60. return variables.ConstantVariable(result, **options).add_guard(
  61. NNModuleSource(AttrSource(self.source, name)).make_guard(
  62. GuardBuilder.HASATTR
  63. )
  64. )
  65. def is_training(self, tx):
  66. mod = tx.output.get_submodule(self.module_key)
  67. return getattr(mod, "training", False)
  68. def convert_to_unspecialized(self, tx):
  69. """Restart analysis treating this module as an UnspecializedNNModuleVariable"""
  70. mod = tx.output.get_submodule(self.module_key)
  71. GenerationTracker.tag(mod)
  72. # Mark the class dynamic unless its module initialization
  73. if tx.f_code.co_name != "__init__":
  74. GenerationTracker.mark_class_dynamic(type(mod))
  75. raise RestartAnalysis()
  76. def var_getattr(self, tx, name):
  77. from .builder import VariableBuilder
  78. options = VariableTracker.propagate(self)
  79. guards = options.get("guards", set())
  80. if self.source:
  81. source = AttrSource(self.source, name)
  82. options["source"] = source
  83. else:
  84. source = None
  85. base = tx.output.get_submodule(self.module_key)
  86. base_dict = object.__getattribute__(base, "__dict__")
  87. object_member = True
  88. all_class_attribute_names = set()
  89. for x in inspect.getmro(base.__class__):
  90. all_class_attribute_names.update(x.__dict__.keys())
  91. if not self.source:
  92. unimplemented("GETATTR with no source")
  93. if name in base_dict:
  94. subobj = base_dict[name]
  95. elif (
  96. "_modules" in base_dict
  97. and name in base_dict["_modules"]
  98. and name not in all_class_attribute_names
  99. ):
  100. subobj = base_dict["_modules"][name]
  101. elif "_parameters" in base_dict and name in base_dict["_parameters"]:
  102. subobj = base_dict["_parameters"][name]
  103. elif "_buffers" in base_dict and name in base_dict["_buffers"]:
  104. subobj = base_dict["_buffers"][name]
  105. else:
  106. subobj = inspect.getattr_static(base, name)
  107. object_member = False
  108. if name == "__class__" and not object_member:
  109. return variables.UserDefinedClassVariable(base.__class__, **options)
  110. if object_member:
  111. return VariableBuilder(tx, NNModuleSource(source))(subobj)
  112. else:
  113. if istype(subobj, property):
  114. return variables.UserFunctionVariable(
  115. subobj.fget,
  116. guards=guards,
  117. source=source,
  118. ).call_function(tx, [(self)], {})
  119. elif istype(subobj, classmethod):
  120. return variables.UserMethodVariable(
  121. subobj.__func__,
  122. variables.UserDefinedObjectVariable(type(base), guards=guards),
  123. **options,
  124. )
  125. elif istype(subobj, staticmethod):
  126. return variables.UserFunctionVariable(subobj.__get__(base), **options)
  127. elif istype(subobj, types.FunctionType):
  128. return variables.UserMethodVariable(subobj, self, **options)
  129. elif is_safe_constant(subobj) or istensor(subobj):
  130. # Support possibly common cases of class members
  131. return VariableBuilder(tx, NNModuleSource(source))(subobj)
  132. else:
  133. unimplemented(f"class property {typestr(base)} {typestr(subobj)}")
  134. return variables.GetAttrVariable(self, name, **options)
  135. def call_function(
  136. self,
  137. tx,
  138. args: "List[VariableTracker]",
  139. kwargs: "Dict[str, VariableTracker]",
  140. ) -> "VariableTracker":
  141. options = VariableTracker.propagate(self, args, kwargs.values())
  142. mod = tx.output.get_submodule(self.module_key)
  143. @contextmanager
  144. def record_nn_module_stack():
  145. try:
  146. tx.nn_module_stack[self.module_key] = type(mod)
  147. yield
  148. finally:
  149. del tx.nn_module_stack[self.module_key]
  150. with record_nn_module_stack():
  151. is_lazy = is_lazy_module(mod)
  152. if (
  153. isinstance(mod, torch.nn.Sequential)
  154. and mod.__class__.forward is torch.nn.Sequential.forward
  155. ):
  156. # unroll Sequential()
  157. assert not kwargs
  158. (arg,) = args
  159. for idx, submod in enumerate(mod):
  160. tx.call_function(
  161. tx.output.register_attr_or_module(
  162. submod,
  163. self.module_key,
  164. idx,
  165. source=NNModuleSource(GetItemSource(self.source, idx)),
  166. **options,
  167. ),
  168. [arg],
  169. {},
  170. )
  171. arg = tx.pop()
  172. return arg
  173. elif is_allowed(mod.__class__):
  174. # The module type will change after it is called
  175. if is_lazy:
  176. self.module_type = mod.cls_to_become
  177. from .builder import wrap_fx_proxy
  178. return wrap_fx_proxy(
  179. tx=tx,
  180. proxy=tx.output.create_proxy(
  181. "call_module",
  182. self.module_key,
  183. *proxy_args_kwargs(args, kwargs),
  184. ),
  185. **options,
  186. )
  187. else:
  188. # for lazy modules, run the pre-hooks which will update the type
  189. # TODO mlazos: we don't fully support all of the hooks that exist,
  190. # so restrict using __call__ only to lazy modules for now
  191. assert self.source, (
  192. "Must provide a valid source in order to inline, "
  193. "since inlined function may have default args which must be guarded."
  194. )
  195. if is_lazy:
  196. if istype(mod.__call__, types.FunctionType):
  197. fn = mod.__call__
  198. fn_source = AttrSource(self.source, "__call__")
  199. else:
  200. assert istype(mod.__call__, types.MethodType)
  201. fn = mod.__call__.__func__
  202. fn_source = AttrSource(
  203. AttrSource(self.source, "__call__"), "__func__"
  204. )
  205. args = [self] + args
  206. else:
  207. if istype(mod.forward, types.FunctionType):
  208. fn = mod.forward
  209. fn_source = AttrSource(self.source, "forward")
  210. else:
  211. assert istype(mod.forward, types.MethodType)
  212. fn = mod.forward.__func__
  213. fn_source = AttrSource(
  214. AttrSource(self.source, "forward"), "__func__"
  215. )
  216. args = [self] + args
  217. options["source"] = fn_source
  218. return tx.inline_user_function_return(
  219. variables.UserFunctionVariable(fn, **options),
  220. args,
  221. kwargs,
  222. )
  223. def call_method(
  224. self,
  225. tx,
  226. name,
  227. args: "List[VariableTracker]",
  228. kwargs: "Dict[str, VariableTracker]",
  229. constant=False,
  230. ) -> "VariableTracker":
  231. from . import ConstantVariable, ListIteratorVariable, TupleVariable
  232. options = VariableTracker.propagate(self, args, kwargs.values())
  233. key = self.module_key
  234. module = tx.output.get_submodule(key)
  235. if name == "forward":
  236. return self.call_function(tx, args, kwargs)
  237. if name == "_check_input_dim" and skipfiles.is_torch_inline_allowed(
  238. inspect.getfile(module.__class__._check_input_dim)
  239. ):
  240. return ConstantVariable(True, **options)
  241. if name == "_get_item_by_idx":
  242. assert args[1].is_python_constant()
  243. assert isinstance(args[0], TupleVariable)
  244. mod_var = args[0].items[args[1].value]
  245. key = mod_var.module_key
  246. submod = tx.output.get_submodule(key)
  247. return tx.output.register_attr_or_module(
  248. submod,
  249. key,
  250. key,
  251. source=NNModuleSource(GetItemSource(self.source, key)),
  252. **options,
  253. )
  254. if constant:
  255. fn = getattr(module, name)
  256. name = f"{module.__class__.__name__}_{name}_result"
  257. return invoke_and_store_as_constant(tx, fn, name, options, args, kwargs)
  258. def assert_all_args_kwargs_const():
  259. if not all(
  260. x.is_python_constant() for x in itertools.chain(args, kwargs.values())
  261. ):
  262. raise unimplemented(f"non-const NNModule method {name}")
  263. def get_kwargs(*names):
  264. assert_all_args_kwargs_const()
  265. fn = getattr(module, name)
  266. bound_args = inspect.signature(fn).bind(
  267. *([x.as_python_constant() for x in args]),
  268. **{k: v.as_python_constant() for k, v in kwargs.items()},
  269. )
  270. bound_args.apply_defaults()
  271. bound_args = bound_args.arguments
  272. return {k: bound_args[k] for k in names}
  273. def wrap_values(items):
  274. result = []
  275. for name, submod in items:
  276. result.append(
  277. tx.output.register_attr_or_module(
  278. submod,
  279. key,
  280. name,
  281. source=NNModuleSource(gen_source(self.source, name)),
  282. **options,
  283. )
  284. )
  285. return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
  286. def named_embed(name, obj):
  287. return TupleVariable(
  288. [
  289. ConstantVariable(name, **options),
  290. tx.output.register_attr_or_module(
  291. obj,
  292. key,
  293. name,
  294. source=NNModuleSource(gen_source(self.source, name)),
  295. **options,
  296. ),
  297. ]
  298. )
  299. def gen_source(source, name):
  300. name_split = name.split(".")
  301. if name_split[0] == "":
  302. return source
  303. while len(name_split) > 0:
  304. x = name_split.pop(0)
  305. source = AttrSource(source, x)
  306. return source
  307. if name == "children":
  308. assert not (args or kwargs)
  309. return wrap_values(module.named_children())
  310. elif name == "named_parameters":
  311. result = []
  312. for name, param in module.named_parameters(
  313. **get_kwargs("prefix", "recurse")
  314. ):
  315. result.append(named_embed(name, param))
  316. return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
  317. elif name == "named_buffers":
  318. result = []
  319. for name, buffer in module.named_buffers(
  320. **get_kwargs("prefix", "recurse", "remove_duplicate")
  321. ):
  322. result.append(named_embed(name, buffer))
  323. return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
  324. elif name == "named_modules":
  325. result = []
  326. for name, submod in module.named_modules(
  327. **get_kwargs("memo", "prefix", "remove_duplicate")
  328. ):
  329. result.append(named_embed(name, submod))
  330. return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
  331. elif name == "modules":
  332. return wrap_values(module.named_modules())
  333. elif name == "parameters":
  334. return wrap_values(module.named_parameters(**get_kwargs("recurse")))
  335. elif name == "keys":
  336. assert not (args or kwargs)
  337. result = []
  338. for name in module.keys():
  339. result.append(ConstantVariable(name, **options))
  340. return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
  341. elif name == "values":
  342. assert not (args or kwargs)
  343. return wrap_values(module.items())
  344. elif name == "items":
  345. assert not (args or kwargs)
  346. result = []
  347. for name, submod in module.items():
  348. result.append(named_embed(name, submod))
  349. return ListIteratorVariable(result, mutable_local=MutableLocal(), **options)
  350. elif name == "__len__":
  351. assert not (args or kwargs)
  352. return ConstantVariable(len(module), **options)
  353. elif (
  354. name == "__contains__"
  355. and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict))
  356. and args
  357. and args[0].is_python_constant()
  358. ):
  359. return ConstantVariable(
  360. args[0].as_python_constant() in module._modules, **options
  361. )
  362. elif name == "__getitem__":
  363. assert not kwargs and len(args) == 1
  364. assert type(module).__getitem__ in (
  365. torch.nn.ModuleDict.__getitem__,
  366. torch.nn.ModuleList.__getitem__,
  367. torch.nn.ParameterList.__getitem__,
  368. torch.nn.Sequential.__getitem__,
  369. ), typestr(module)
  370. assert self.source
  371. if isinstance(args[0], SliceVariable):
  372. # Build a TupleVariable of NNModules
  373. result = []
  374. submods = []
  375. # Turn the slice into the list of integers
  376. keys = list(range(len(module)))[args[0].as_python_constant()]
  377. for idx, submod in enumerate(module[args[0].as_python_constant()]):
  378. key = keys[idx]
  379. src = NNModuleSource(GetItemSource(self.source, key))
  380. result.append(
  381. tx.output.register_attr_or_module(
  382. submod,
  383. key,
  384. source=src,
  385. **options,
  386. )
  387. )
  388. submods.append(submod)
  389. new_module = torch.nn.Sequential(*submods)
  390. new_module_variable = tx.output.register_attr_or_module(
  391. new_module,
  392. f"{self}.__getitem__(slice)",
  393. source=NNModuleSource(
  394. GetItemSource(self.source, args[0].as_python_constant())
  395. ),
  396. **options,
  397. )
  398. return new_module_variable
  399. key = args[0].as_python_constant()
  400. submod = module[key]
  401. return tx.output.register_attr_or_module(
  402. submod,
  403. key,
  404. args[0].as_python_constant(),
  405. source=NNModuleSource(GetItemSource(self.source, key)),
  406. **options,
  407. )
  408. elif name == "_get_abs_string_index":
  409. # Inline the function
  410. fn = getattr(module, name).__func__
  411. src = AttrSource(AttrSource(self.source, name), "__func__")
  412. return tx.inline_user_function_return(
  413. variables.UserFunctionVariable(fn, source=src, **options),
  414. [self] + args,
  415. kwargs,
  416. )
  417. # A loose heuristic, but seems to be generally good before we drop into the
  418. # manual handling of inputs
  419. elif (
  420. name in module.__class__.__dict__
  421. and callable(module.__class__.__dict__[name])
  422. and all(
  423. isinstance(x, variables.TensorVariable)
  424. for x in itertools.chain(args, kwargs.values())
  425. )
  426. ):
  427. # TODO(voz): Refactor this into a generic as_proxy() for nn module
  428. # We use variations of this pattern in a few places now.
  429. def make_attr(name):
  430. node = tx.output.create_proxy(
  431. "get_attr",
  432. name,
  433. tuple(),
  434. {},
  435. )
  436. return node
  437. # Bind in self
  438. tx.output.register_attr_or_module(
  439. module,
  440. self.module_key,
  441. self.module_key,
  442. source=NNModuleSource(GetItemSource(self.source, self.module_key)),
  443. **options,
  444. )
  445. proxy_for_mod = make_attr(self.module_key)
  446. proxy_for_mod.node.meta["example_value"] = module
  447. proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs)
  448. from .builder import wrap_fx_proxy
  449. return wrap_fx_proxy(
  450. tx=tx,
  451. proxy=tx.output.create_proxy(
  452. "call_method",
  453. name,
  454. args=(proxy_for_mod, *proxy_args),
  455. kwargs=proxy_kwargs,
  456. ),
  457. **options,
  458. )
  459. else:
  460. return super().call_method(tx, name, args, kwargs)
  461. class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
  462. """
  463. The above class will specialize on the id() of a module and place
  464. parameters on the torch.fx.GraphModule. Giving one graph per
  465. module instance. This version treats nn.Modules() like other user
  466. defined objects and will pass parameters into the FX graph as inputs.
  467. Giving one graph per module class.
  468. """
  469. def __init__(self, value, **kwargs):
  470. super().__init__(value=value, **kwargs)
  471. if self.source and self.source.is_nn_module():
  472. # force guard checks even when `not config.guard_nn_modules``
  473. self.source = NotNNModuleSource(self.source)
  474. @staticmethod
  475. @functools.lru_cache(None)
  476. def _nn_module_method_ids():
  477. return {
  478. id(x.__code__)
  479. for x in torch.nn.Module.__dict__.values()
  480. if hasattr(x, "__code__")
  481. }
  482. def unpack_var_sequence(self, tx):
  483. from .builder import VariableBuilder
  484. try:
  485. fn = inspect.getattr_static(self.value_type, "__iter__")
  486. except AttributeError as e:
  487. raise NotImplementedError from e
  488. if fn in (
  489. torch.nn.ModuleList.__iter__,
  490. torch.nn.ParameterList.__iter__,
  491. torch.nn.Sequential.__iter__,
  492. ):
  493. assert self.source
  494. return [
  495. VariableBuilder(tx, source=GetItemSource(self.source, idx))(
  496. item
  497. ).add_options(self)
  498. for idx, item in enumerate(self.value)
  499. ]
  500. return super().unpack_var_sequence(tx)
  501. def call_function(
  502. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  503. ) -> "VariableTracker":
  504. options = VariableTracker.propagate(self, args, kwargs.values())
  505. # TODO mlazos: only support __call__ for lazy modules
  506. # until we can support a larger swath of python
  507. if is_lazy_module(self.value):
  508. fn = self.value_type.__call__
  509. source = AttrSource(AttrSource(self.source, "__class__"), "__call__")
  510. else:
  511. fn = self.value_type.forward
  512. source = AttrSource(AttrSource(self.source, "__class__"), "forward")
  513. return variables.UserFunctionVariable(
  514. fn, source=source, **options
  515. ).call_function(tx, [self] + list(args), kwargs)
  516. def call_method(
  517. self,
  518. tx,
  519. name,
  520. args: "List[VariableTracker]",
  521. kwargs: "Dict[str, VariableTracker]",
  522. ) -> "VariableTracker":
  523. from .builder import VariableBuilder
  524. options = VariableTracker.propagate(self, args, kwargs.values())
  525. if name not in getattr(self.value, "__dict__", {}):
  526. try:
  527. method = inspect.getattr_static(type(self.value), name)
  528. except AttributeError:
  529. method = None
  530. if method is torch.nn.Module.parameters:
  531. assert not args or kwargs
  532. options["guards"].add(
  533. self.source.make_guard(GuardBuilder.NN_MODULE_PARAM_NAMES)
  534. )
  535. items = []
  536. for name, value in self.value.named_parameters():
  537. items.append(
  538. VariableBuilder(tx, AttrSource(self.source, name))(
  539. value
  540. ).add_options(options)
  541. )
  542. return variables.ListIteratorVariable(
  543. items, mutable_local=MutableLocal(), **options
  544. )
  545. elif isinstance(method, staticmethod):
  546. source = AttrSource(
  547. AttrSource(AttrSource(self.source, "__class__"), name), "__func__"
  548. )
  549. return tx.inline_user_function_return(
  550. variables.UserFunctionVariable(
  551. method.__func__, source=source, **options
  552. ),
  553. args,
  554. kwargs,
  555. )
  556. if id(method.__code__) in self._nn_module_method_ids():
  557. unimplemented(f"UnspecializedNNModuleVariable missing {name}")
  558. return super().call_method(tx, name, args, kwargs)