base.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import collections
  2. from typing import Any, Callable, Dict, List, Optional, Set
  3. from .. import variables
  4. from ..exc import unimplemented
  5. from ..source import AttrSource, Source
  6. from ..utils import dict_values, identity, istype, odict_values
  7. class MutableLocal:
  8. """
  9. Marker used to indicate this (list, iter, etc) was constructed in
  10. local scope and can be mutated safely in analysis without leaking
  11. state.
  12. """
  13. def __hash__(self):
  14. return id(self)
  15. def __eq__(self, other):
  16. return self is other
  17. # metaclass to call post_init
  18. class HasPostInit(type):
  19. def __call__(cls, *args, **kwargs):
  20. obj = type.__call__(cls, *args, **kwargs)
  21. obj.__post_init__(*args, **kwargs)
  22. return obj
  23. class VariableTracker(metaclass=HasPostInit):
  24. """
  25. Base class for tracked locals and stack values
  26. VariableTracker instances are immutable and should be copied in
  27. order to change them.
  28. """
  29. # fields to leave unmodified in apply()
  30. _nonvar_fields = ["value"]
  31. @staticmethod
  32. def propagate(*vars: List[List["VariableTracker"]]):
  33. """Combine the guards from many VariableTracker into **kwargs for a new instance"""
  34. guards = set()
  35. def visit(var):
  36. if type(var) in (list, tuple, dict_values, odict_values):
  37. for i in var:
  38. visit(i)
  39. else:
  40. assert isinstance(var, VariableTracker), typestr(var)
  41. guards.update(var.guards)
  42. visit(vars)
  43. return {
  44. "guards": guards,
  45. }
  46. def clone(self, **kwargs):
  47. """Shallow copy with some (optional) changes"""
  48. args = dict(self.__dict__)
  49. args.update(kwargs)
  50. return self.__class__(**args)
  51. @classmethod
  52. def copy(cls, value):
  53. """Deeper (but not full) copy, leaving FX and user objects alone"""
  54. return cls.apply(identity, value)
  55. @classmethod
  56. def apply(
  57. cls,
  58. fn: Callable[["VariableTracker"], "VariableTracker"],
  59. value,
  60. cache=None,
  61. skip_fn=lambda _: False, # Whether we should skip applying to this var
  62. ):
  63. """
  64. Walk this object and call fn on all the VariableTracker
  65. instances to produce a new VariableTracker with the results.
  66. """
  67. if cache is None:
  68. cache = dict()
  69. idx = id(value)
  70. if idx in cache:
  71. return cache[idx][0]
  72. if isinstance(value, VariableTracker):
  73. if not skip_fn(value):
  74. updated_dict = dict(value.__dict__)
  75. for key in updated_dict.keys():
  76. if key not in value._nonvar_fields:
  77. updated_dict[key] = cls.apply(
  78. fn, updated_dict[key], cache, skip_fn
  79. )
  80. result = fn(value.clone(**updated_dict))
  81. else:
  82. result = fn(value)
  83. elif istype(value, list):
  84. result = [cls.apply(fn, v, cache, skip_fn) for v in value]
  85. elif istype(value, tuple):
  86. result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value)
  87. elif istype(value, collections.OrderedDict):
  88. result = collections.OrderedDict(
  89. cls.apply(fn, v, cache, skip_fn) for v in value.items()
  90. )
  91. elif istype(value, dict):
  92. result = {
  93. k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())
  94. }
  95. else:
  96. result = value
  97. # save `value` to keep it alive and ensure id() isn't reused
  98. cache[idx] = (result, value)
  99. return result
  100. def add_guard(self, guard):
  101. return self.clone(guards=set.union(self.guards, {guard}))
  102. def add_guards(self, guards):
  103. if guards is None:
  104. return self
  105. assert isinstance(guards, set)
  106. return self.clone(guards=set.union(self.guards, guards))
  107. def add_options(self, options, *more):
  108. if more:
  109. return self.add_options(options).add_options(*more)
  110. if isinstance(options, VariableTracker):
  111. return self.add_guards(options.guards)
  112. assert isinstance(options, dict)
  113. return self.add_guards(options.get("guards", set()))
  114. def __str__(self):
  115. return f"{self.__class__.__name__}()"
  116. def __repr__(self):
  117. return str(self)
  118. def python_type(self):
  119. raise NotImplementedError(f"{self} has no type")
  120. def as_python_constant(self):
  121. """For constants"""
  122. raise NotImplementedError(f"{self} is not a constant")
  123. def is_python_constant(self):
  124. try:
  125. self.as_python_constant()
  126. return True
  127. except NotImplementedError:
  128. return False
  129. def as_specialized(self, tx):
  130. """
  131. For specialized variables, return itself,
  132. For unspecialized variables, convert to constant variable and return.
  133. """
  134. return self
  135. def can_make_guard(self):
  136. try:
  137. self.make_guard(None)
  138. return True
  139. except NotImplementedError:
  140. return False
  141. def make_guard(self, fn):
  142. if self.source:
  143. return self.source.make_guard(fn)
  144. raise NotImplementedError()
  145. def replace_guards(self, guards, *fns):
  146. name = self.source.name()
  147. new_guards = {g for g in (guards or []) if g.name != name}
  148. new_guards.update(self.source.make_guard(fn) for fn in fns)
  149. return new_guards
  150. def const_getattr(self, tx, name: str) -> Any:
  151. """getattr(self, name) returning a python constant"""
  152. raise NotImplementedError()
  153. def var_getattr(self, tx, name: str) -> "VariableTracker":
  154. """getattr(self, name) returning a new variable"""
  155. options = VariableTracker.propagate(self)
  156. value = self.const_getattr(tx, name)
  157. if not variables.ConstantVariable.is_literal(value):
  158. raise NotImplementedError()
  159. if self.source:
  160. options["source"] = AttrSource(self.source, name)
  161. return variables.ConstantVariable(value, **options)
  162. def is_proxy(self):
  163. try:
  164. self.as_proxy()
  165. return True
  166. except NotImplementedError:
  167. return False
  168. def as_proxy(self):
  169. raise NotImplementedError(str(self))
  170. def reconstruct(self, codegen):
  171. raise NotImplementedError()
  172. def unpack_var_sequence(self, tx):
  173. raise NotImplementedError()
  174. def has_unpack_var_sequence(self, tx):
  175. try:
  176. self.unpack_var_sequence(tx)
  177. return True
  178. except NotImplementedError:
  179. return False
  180. def num_parameters(self):
  181. unimplemented(f"num_parameters: {self}")
  182. def call_hasattr(self, tx, name: str) -> "VariableTracker":
  183. unimplemented(f"hasattr: {repr(self)}")
  184. def call_function(
  185. self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
  186. ) -> "VariableTracker":
  187. unimplemented(f"call_function {self} {args} {kwargs}")
  188. def call_method(
  189. self,
  190. tx,
  191. name,
  192. args: "List[VariableTracker]",
  193. kwargs: "Dict[str, VariableTracker]",
  194. ) -> "VariableTracker":
  195. if name == "__len__" and self.has_unpack_var_sequence(tx):
  196. assert not (args or kwargs)
  197. return variables.ConstantVariable(
  198. len(self.unpack_var_sequence(tx)), **VariableTracker.propagate(self)
  199. )
  200. elif (
  201. name == "__getattr__"
  202. and len(args) == 1
  203. and args[0].is_python_constant()
  204. and not kwargs
  205. ):
  206. return self.var_getattr(tx, args[0].as_python_constant()).add_options(
  207. self, args[0]
  208. )
  209. raise unimplemented(f"call_method {self} {name} {args} {kwargs}")
  210. def __init__(
  211. self,
  212. guards: Optional[Set] = None,
  213. source: Source = None,
  214. mutable_local: MutableLocal = None,
  215. recursively_contains: Optional[Set] = None,
  216. ):
  217. super().__init__()
  218. self.guards = guards or set()
  219. self.source = source
  220. self.mutable_local = mutable_local
  221. self.recursively_contains = (
  222. recursively_contains # provides hint to replace_all when replacing vars
  223. )
  224. def __post_init__(self, *args, **kwargs):
  225. if self.recursively_contains is None:
  226. self.recursively_contains = set()
  227. def aggregate_mutables(var):
  228. self.recursively_contains.update(var.recursively_contains)
  229. if var.mutable_local is not None:
  230. self.recursively_contains.add(var.mutable_local)
  231. return var
  232. VariableTracker.apply(
  233. aggregate_mutables, self, skip_fn=lambda var: var is not self
  234. )
  235. assert None not in self.recursively_contains
  236. def typestr(*objs):
  237. if len(objs) == 1:
  238. (obj,) = objs
  239. if isinstance(obj, VariableTracker):
  240. return str(obj)
  241. else:
  242. return type(obj).__name__
  243. else:
  244. return " ".join(map(typestr, objs))