user_defined.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. import collections
  2. import contextlib
  3. import functools
  4. import importlib
  5. import inspect
  6. import random
  7. import types
  8. from typing import Dict, List
  9. import torch.nn
  10. from .. import variables
  11. from ..exc import unimplemented
  12. from ..guards import GuardBuilder
  13. from ..source import AttrSource, ODictGetItemSource, RandomValueSource
  14. from ..utils import is_namedtuple_cls, namedtuple_fields
  15. from .base import MutableLocal, VariableTracker
  16. from .misc import NullContextVariable
  17. class UserDefinedVariable(VariableTracker):
  18. pass
  19. class UserDefinedClassVariable(UserDefinedVariable):
  20. def __init__(self, value, **kwargs):
  21. super().__init__(**kwargs)
  22. self.value = value
  23. def as_python_constant(self):
  24. return self.value
  25. def python_type(self):
  26. return type(self.value)
  27. def var_getattr(self, tx, name: str) -> "VariableTracker":
  28. from . import ConstantVariable
  29. from .builder import VariableBuilder
  30. options = VariableTracker.propagate(self)
  31. source = AttrSource(self.source, name) if self.source is not None else None
  32. try:
  33. obj = inspect.getattr_static(self.value, name)
  34. except AttributeError:
  35. obj = None
  36. if isinstance(obj, staticmethod):
  37. return variables.UserFunctionVariable(
  38. obj.__get__(self.value), source=source, **options
  39. )
  40. elif isinstance(obj, classmethod):
  41. return variables.UserMethodVariable(
  42. obj.__func__, self, source=source, **options
  43. )
  44. if name in getattr(self.value, "__dict__", {}) or ConstantVariable.is_literal(
  45. obj
  46. ):
  47. if source:
  48. return VariableBuilder(tx, source)(obj).add_options(options)
  49. elif ConstantVariable.is_literal(obj):
  50. return ConstantVariable(obj, **options)
  51. return super().var_getattr(tx, name)
  52. def call_method(
  53. self,
  54. tx,
  55. name,
  56. args: "List[VariableTracker]",
  57. kwargs: "Dict[str, VariableTracker]",
  58. ) -> "VariableTracker":
  59. if (
  60. name == "__subclasses__"
  61. and len(args) == 0
  62. and not kwargs
  63. and "__subclasses__" not in self.value.__dict__
  64. ):
  65. options = VariableTracker.propagate(self, args, kwargs.values())
  66. options["mutable_local"] = MutableLocal()
  67. subs_as_vars: List[VariableTracker] = list()
  68. for sub in self.value.__subclasses__():
  69. source = AttrSource(tx.import_source(sub.__module__), sub.__name__)
  70. subs_as_vars.append(
  71. variables.UserDefinedClassVariable(sub, source=source)
  72. )
  73. return variables.ListVariable(subs_as_vars, **options)
  74. return super().call_method(tx, name, args, kwargs)
  75. def call_function(
  76. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  77. ) -> "VariableTracker":
  78. from ..side_effects import SideEffects
  79. options = VariableTracker.propagate(self, args, kwargs.values())
  80. if self.value in (
  81. contextlib.nullcontext,
  82. torch.autograd.profiler.profile,
  83. ):
  84. return NullContextVariable(**options)
  85. elif is_namedtuple_cls(self.value):
  86. fields = namedtuple_fields(self.value)
  87. items = list(args)
  88. items.extend([None] * (len(fields) - len(items)))
  89. for name, value in kwargs.items():
  90. assert name in fields
  91. items[fields.index(name)] = value
  92. assert all(x is not None for x in items)
  93. return variables.NamedTupleVariable(
  94. items, self.value, **VariableTracker.propagate(self, items)
  95. )
  96. elif (
  97. inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
  98. and SideEffects.cls_supports_mutation_side_effects(self.value)
  99. and self.source
  100. ):
  101. var = tx.output.side_effects.track_object_new(
  102. self.source, self.value, UserDefinedObjectVariable, options
  103. )
  104. return var.add_options(var.call_method(tx, "__init__", args, kwargs))
  105. elif variables.DataClassVariable.is_matching_cls(self.value):
  106. options["mutable_local"] = MutableLocal()
  107. return variables.DataClassVariable.create(self.value, args, kwargs, options)
  108. return super().call_function(tx, args, kwargs)
  109. def const_getattr(self, tx, name):
  110. if name == "__name__":
  111. return self.value.__name__
  112. return super().const_getattr(tx, name)
  113. class UserDefinedObjectVariable(UserDefinedVariable):
  114. """
  115. Mostly objects of defined type. Catch-all for something where we only know the type.
  116. """
  117. def __init__(self, value, value_type=None, **kwargs):
  118. super().__init__(**kwargs)
  119. self.value = value
  120. self.value_type = value_type or type(value)
  121. assert type(value) is self.value_type
  122. def __str__(self):
  123. inner = self.value_type.__name__
  124. if inner in [
  125. "builtin_function_or_method",
  126. "getset_descriptor",
  127. "method_descriptor",
  128. "method",
  129. ]:
  130. inner = str(getattr(self.value, "__name__", None))
  131. return f"{self.__class__.__name__}({inner})"
  132. def python_type(self):
  133. return self.value_type
  134. @staticmethod
  135. @functools.lru_cache(None)
  136. def _supported_random_functions():
  137. fns = {
  138. random.random,
  139. random.randint,
  140. random.randrange,
  141. random.uniform,
  142. }
  143. return fns
  144. def call_method(
  145. self,
  146. tx,
  147. name,
  148. args: "List[VariableTracker]",
  149. kwargs: "Dict[str, VariableTracker]",
  150. ) -> "VariableTracker":
  151. from . import ConstantVariable, TupleVariable, UserMethodVariable
  152. options = VariableTracker.propagate(self, args, kwargs.values())
  153. if name not in getattr(self.value, "__dict__", {}):
  154. try:
  155. method = inspect.getattr_static(type(self.value), name)
  156. except AttributeError:
  157. method = None
  158. if method is object.__init__:
  159. return ConstantVariable(None, **options)
  160. if method is collections.OrderedDict.keys and self.source:
  161. # subclass of OrderedDict
  162. assert not (args or kwargs)
  163. keys = list(self.value.keys())
  164. assert all(map(ConstantVariable.is_literal, keys))
  165. return TupleVariable(
  166. [ConstantVariable(k, **options) for k in keys], **options
  167. ).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))
  168. if (
  169. method is collections.OrderedDict.items
  170. and isinstance(self.value, collections.OrderedDict)
  171. and self.source
  172. ):
  173. assert not (args or kwargs)
  174. items = []
  175. keys = self.call_method(tx, "keys", [], {})
  176. options = VariableTracker.propagate(self, args, kwargs.values(), keys)
  177. for key in keys.unpack_var_sequence(tx):
  178. items.append(
  179. TupleVariable(
  180. [key, self.odict_getitem(tx, key)],
  181. **options,
  182. )
  183. )
  184. return TupleVariable(items, **options)
  185. if method is collections.OrderedDict.__getitem__ and len(args) == 1:
  186. assert not kwargs
  187. return self.odict_getitem(tx, args[0])
  188. # check for methods implemented in C++
  189. if isinstance(method, types.FunctionType):
  190. source = (
  191. None
  192. if self.source is None
  193. else AttrSource(AttrSource(self.source, "__class__"), name)
  194. )
  195. # TODO(jansel): add a guard to check for monkey patching?
  196. return UserMethodVariable(
  197. method, self, source=source, **options
  198. ).call_function(tx, args, kwargs)
  199. return super().call_method(tx, name, args, kwargs)
  200. def is_supported_random(self):
  201. try:
  202. return self.value in self._supported_random_functions()
  203. except TypeError:
  204. # TypeError: unhashable type
  205. return False
  206. def call_function(
  207. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  208. ) -> "VariableTracker":
  209. from .builder import VariableBuilder
  210. if (
  211. self.is_supported_random()
  212. and all(k.is_python_constant() for k in args)
  213. and all(v.is_python_constant() for v in kwargs.values())
  214. ):
  215. args = [x.as_python_constant() for x in args]
  216. kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  217. random_call_index = len(tx.random_calls)
  218. if random_call_index == 0:
  219. tx.output.initial_random_state = random.getstate()
  220. example_value = self.value(*args, **kwargs)
  221. source = RandomValueSource(random_call_index)
  222. tx.random_calls.append((self.value, args, kwargs))
  223. return VariableBuilder(tx, source).wrap_unspecialized_primitive(
  224. example_value
  225. )
  226. return super().call_function(tx, args, kwargs)
  227. def _check_for_getattribute(self):
  228. try:
  229. if isinstance(
  230. inspect.getattr_static(type(self.value), "__getattribute__"),
  231. types.FunctionType,
  232. ):
  233. unimplemented("UserDefinedObjectVariable with custom __getattribute__")
  234. except AttributeError:
  235. pass
  236. def _check_for_getattr(self):
  237. try:
  238. getattr_fn = inspect.getattr_static(type(self.value), "__getattr__")
  239. except AttributeError:
  240. getattr_fn = None
  241. if getattr_fn is torch.nn.Module.__getattr__:
  242. # ignore this case of getattr
  243. getattr_fn = None
  244. return getattr_fn
  245. def _getattr_static(self, name):
  246. if (
  247. isinstance(self.value, torch.nn.Module)
  248. or "__slots__" in self.value.__class__.__dict__
  249. ):
  250. # getattr_static doesn't work on these
  251. subobj = getattr(self.value, name)
  252. else:
  253. subobj = inspect.getattr_static(self.value, name)
  254. return subobj
  255. def var_getattr(self, tx, name):
  256. from . import ConstantVariable
  257. from .builder import VariableBuilder
  258. options = VariableTracker.propagate(self)
  259. value = self.value
  260. source = AttrSource(self.source, name) if self.source else None
  261. self._check_for_getattribute()
  262. getattr_fn = self._check_for_getattr()
  263. try:
  264. subobj = self._getattr_static(name)
  265. except AttributeError:
  266. subobj = None
  267. if isinstance(getattr_fn, types.FunctionType):
  268. return variables.UserMethodVariable(
  269. getattr_fn, self, source=source, **options
  270. ).call_function(tx, [ConstantVariable(name)], {})
  271. elif getattr_fn is not None:
  272. unimplemented("UserDefined with non-function __getattr__")
  273. if isinstance(subobj, property):
  274. return variables.UserMethodVariable(
  275. subobj.fget, self, source=source, **options
  276. ).call_function(tx, [], {})
  277. elif isinstance(subobj, staticmethod):
  278. return variables.UserFunctionVariable(
  279. subobj.__get__(self.value), source=source, **options
  280. )
  281. elif isinstance(subobj, classmethod):
  282. return variables.UserMethodVariable(
  283. subobj.__func__, self, source=source, **options
  284. )
  285. elif isinstance(subobj, types.FunctionType):
  286. return variables.UserMethodVariable(subobj, self, source=source, **options)
  287. if (
  288. name in getattr(value, "__dict__", {})
  289. or ConstantVariable.is_literal(subobj)
  290. or isinstance(
  291. subobj,
  292. (
  293. torch.Tensor,
  294. torch.nn.Module,
  295. ),
  296. )
  297. ):
  298. if source:
  299. return VariableBuilder(tx, source)(subobj).add_options(options)
  300. elif ConstantVariable.is_literal(subobj):
  301. return ConstantVariable(subobj, **options)
  302. if (
  303. name not in getattr(value, "__dict__", {})
  304. and type(value).__module__.startswith("torch.")
  305. and "torch.optim" not in type(value).__module__
  306. and not callable(value)
  307. ):
  308. if not source:
  309. assert getattr(
  310. importlib.import_module(type(value).__module__),
  311. type(value).__name__,
  312. ) is type(value)
  313. source = AttrSource(
  314. AttrSource(
  315. tx.import_source(type(value).__module__), type(value).__name__
  316. ),
  317. name,
  318. )
  319. return VariableBuilder(tx, source)(subobj).add_options(options)
  320. options["source"] = source
  321. if isinstance(
  322. subobj,
  323. (
  324. torch.distributions.constraints._Interval,
  325. torch.distributions.constraints._Real,
  326. torch.distributions.constraints.Constraint,
  327. ),
  328. ):
  329. return UserDefinedObjectVariable(subobj, **options)
  330. if name == "__class__":
  331. return UserDefinedClassVariable(type(self.value), **options)
  332. return variables.GetAttrVariable(self, name, **options)
  333. def call_hasattr(self, tx, name: str) -> "VariableTracker":
  334. if not self.source:
  335. unimplemented("hasattr no source")
  336. options = VariableTracker.propagate(self)
  337. options["guards"].add(
  338. AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
  339. )
  340. if self._check_for_getattribute() or self._check_for_getattr():
  341. unimplemented("hasattr with custom __getattr__")
  342. try:
  343. self._getattr_static(name)
  344. return variables.ConstantVariable(True, **options)
  345. except AttributeError:
  346. return variables.ConstantVariable(False, **options)
  347. def odict_getitem(self, tx, key):
  348. from .builder import VariableBuilder
  349. return VariableBuilder(
  350. tx,
  351. ODictGetItemSource(self.source, key.as_python_constant()),
  352. )(
  353. collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
  354. ).add_options(
  355. key, self
  356. )