dicts.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. import collections
  2. import dataclasses
  3. import functools
  4. import inspect
  5. from typing import Dict, List
  6. from .. import variables
  7. from ..bytecode_transformation import create_instruction
  8. from ..eval_frame import skip_code
  9. from ..exc import unimplemented
  10. from ..source import AttrSource, GlobalWeakRefSource
  11. from ..utils import global_key_name, istensor
  12. from .base import MutableLocal, VariableTracker
  13. from .constant import ConstantVariable
  14. from .tensor import TensorVariable
  15. class ConstDictVariable(VariableTracker):
  16. def __init__(self, items, user_cls, recursively_contains=None, **kwargs):
  17. super().__init__(recursively_contains=recursively_contains, **kwargs)
  18. self.guards.update(VariableTracker.propagate(items.values())["guards"])
  19. self.items = items
  20. self.user_cls = user_cls
  21. def as_proxy(self):
  22. return {k: v.as_proxy() for k, v in self.items.items()}
  23. def as_python_constant(self):
  24. return {k: v.as_python_constant() for k, v in self.items.items()}
  25. def python_type(self):
  26. return self.user_cls
  27. def reconstruct(self, codegen):
  28. for key, value in self.items.items():
  29. if istensor(key):
  30. codegen.extend_output(
  31. [
  32. codegen.create_load_global(global_key_name(key), add=True),
  33. create_instruction("CALL_FUNCTION", 0),
  34. ]
  35. )
  36. else:
  37. codegen.append_output(codegen.create_load_const(key))
  38. codegen(self.items[key])
  39. return [create_instruction("BUILD_MAP", len(self.items))]
  40. def getitem_const(self, arg: VariableTracker):
  41. return self.items[ConstDictVariable.get_key(arg)].add_options(self, arg)
  42. def call_method(
  43. self,
  44. tx,
  45. name,
  46. args: "List[VariableTracker]",
  47. kwargs: "Dict[str, VariableTracker]",
  48. ) -> "VariableTracker":
  49. from . import ConstantVariable, TupleVariable
  50. options = VariableTracker.propagate(self, args, kwargs.values())
  51. val = self.items
  52. if name == "__getitem__":
  53. return self.getitem_const(args[0])
  54. elif name == "items":
  55. assert not (args or kwargs)
  56. return TupleVariable(
  57. [
  58. TupleVariable(
  59. [
  60. ConstDictVariable._key_to_var(
  61. tx,
  62. k,
  63. **options,
  64. ),
  65. v,
  66. ],
  67. **options,
  68. )
  69. for k, v in val.items()
  70. ],
  71. **options,
  72. )
  73. elif name == "keys":
  74. assert not (args or kwargs)
  75. return TupleVariable(
  76. [
  77. ConstDictVariable._key_to_var(
  78. tx,
  79. k,
  80. **options,
  81. )
  82. for k in val.keys()
  83. ],
  84. **options,
  85. )
  86. elif name == "values":
  87. assert not (args or kwargs)
  88. return TupleVariable(list(val.values()), **options)
  89. elif name == "__len__":
  90. assert not (args or kwargs)
  91. return ConstantVariable(len(self.items), **options)
  92. elif (
  93. name == "__setitem__"
  94. and args
  95. and ConstDictVariable.is_valid_key(args[0])
  96. and self.mutable_local
  97. ):
  98. assert not kwargs and len(args) == 2
  99. k = ConstDictVariable.get_key(args[0])
  100. if istensor(k):
  101. tx.store_dict_key(global_key_name(k), k)
  102. newval = collections.OrderedDict(val)
  103. newval[k] = args[1]
  104. new_rec_contains = self.recursively_contains.union(
  105. args[1].recursively_contains
  106. )
  107. if args[1].mutable_local is not None:
  108. new_rec_contains.add(args[1].mutable_local)
  109. return tx.replace_all(
  110. self,
  111. self.modifed(newval, new_rec_contains, **options),
  112. )
  113. elif (
  114. name in ("pop", "get")
  115. and args
  116. and ConstDictVariable.is_valid_key(args[0])
  117. and ConstDictVariable.get_key(args[0]) not in self.items
  118. and len(args) == 2
  119. ):
  120. # missing item, return the default value
  121. return args[1].add_options(options)
  122. elif (
  123. name == "pop"
  124. and args
  125. and ConstDictVariable.is_valid_key(args[0])
  126. and self.mutable_local
  127. ):
  128. newval = collections.OrderedDict(val)
  129. result = newval.pop(ConstDictVariable.get_key(args[0]))
  130. tx.replace_all(self, self.modifed(newval, None, **options))
  131. return result.add_options(options)
  132. elif (
  133. name == "update"
  134. and args
  135. and isinstance(args[0], ConstDictVariable)
  136. and self.mutable_local
  137. ):
  138. newval = collections.OrderedDict(val)
  139. newval.update(args[0].items)
  140. new_rec_contains = self.recursively_contains.union(
  141. args[0].recursively_contains
  142. )
  143. result = self.modifed(
  144. newval, recursively_contains=new_rec_contains, **options
  145. )
  146. return tx.replace_all(self, result)
  147. elif (
  148. name in ("get", "__getattr__")
  149. and args
  150. and ConstDictVariable.is_valid_key(args[0])
  151. and ConstDictVariable.get_key(args[0]) in self.items
  152. ):
  153. result = self.items[ConstDictVariable.get_key(args[0])]
  154. return result.add_options(options)
  155. elif (
  156. name == "__contains__" and args and ConstDictVariable.is_valid_key(args[0])
  157. ):
  158. return ConstantVariable(
  159. ConstDictVariable.get_key(args[0]) in self.items, **options
  160. )
  161. else:
  162. return super().call_method(tx, name, args, kwargs)
  163. def modifed(self, items, recursively_contains, **options):
  164. """a copy of self with different items"""
  165. return self.clone(
  166. items=items, recursively_contains=recursively_contains, **options
  167. )
  168. def unpack_var_sequence(self, tx):
  169. options = VariableTracker.propagate([self])
  170. val = self.items
  171. result = [ConstDictVariable._key_to_var(tx, k, **options) for k in val.keys()]
  172. return result
  173. @classmethod
  174. def get_key(cls, arg: VariableTracker):
  175. if isinstance(arg, TensorVariable) and arg.specialized_value is not None:
  176. return arg.specialized_value
  177. else:
  178. return arg.as_python_constant()
  179. @classmethod
  180. def is_valid_key(cls, key):
  181. return (
  182. key.is_python_constant()
  183. or isinstance(key, TensorVariable)
  184. and key.specialized_value is not None
  185. )
  186. @classmethod
  187. def _key_to_var(cls, tx, key, **options):
  188. from .builder import VariableBuilder
  189. if istensor(key):
  190. return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(key)))(key)
  191. else:
  192. assert ConstantVariable.is_literal(key)
  193. return ConstantVariable(key, **options)
  194. class DefaultDictVariable(ConstDictVariable):
  195. def __init__(self, items, user_cls, default_factory=None, **kwargs):
  196. super().__init__(items, user_cls, **kwargs)
  197. assert user_cls is collections.defaultdict
  198. self.default_factory = default_factory
  199. def call_method(
  200. self,
  201. tx,
  202. name,
  203. args: "List[VariableTracker]",
  204. kwargs: "Dict[str, VariableTracker]",
  205. ) -> "VariableTracker":
  206. from . import ListVariable, TupleVariable
  207. options = VariableTracker.propagate(self, args, kwargs.values())
  208. if name == "__getitem__":
  209. k = ConstDictVariable.get_key(args[0])
  210. if k in self.items:
  211. return self.getitem_const(args[0])
  212. else:
  213. if self.default_factory is None:
  214. raise KeyError(f"{k}")
  215. else:
  216. if istensor(k):
  217. tx.store_dict_key(global_key_name(k), k)
  218. new_val = collections.OrderedDict(self.items)
  219. if self.default_factory is list:
  220. default_var = ListVariable([], mutable_local=MutableLocal())
  221. elif self.default_factory is tuple:
  222. default_var = TupleVariable([], mutable_local=MutableLocal())
  223. elif self.default_factory is dict:
  224. default_var = ConstDictVariable(
  225. {}, dict, mutable_local=MutableLocal()
  226. )
  227. else:
  228. unimplemented(
  229. f"defaultdict with default_factory = {self.default_factory}"
  230. )
  231. new_val[k] = default_var
  232. new_rec_contains = self.recursively_contains.union(
  233. default_var.recursively_contains
  234. )
  235. if default_var.mutable_local is not None:
  236. new_rec_contains.add(default_var.mutable_local)
  237. tx.replace_all(
  238. self, self.modifed(new_val, new_rec_contains, **options)
  239. )
  240. return default_var
  241. else:
  242. return super().call_method(tx, name, args, kwargs)
  243. class DataClassVariable(ConstDictVariable):
  244. """
  245. This is a bit of a hack to deal with
  246. transformers.file_utils.ModelOutput() from huggingface.
  247. ModelOutput causes trouble because it a a mix of a dataclass and a
  248. OrderedDict and it calls super() methods implemented in C.
  249. """
  250. # ModelOutput() excludes None, though generic datclasses don't
  251. include_none = False
  252. @staticmethod
  253. @functools.lru_cache(None)
  254. def _patch_once():
  255. from transformers.file_utils import ModelOutput
  256. for obj in ModelOutput.__dict__.values():
  257. if callable(obj):
  258. skip_code(obj.__code__)
  259. @staticmethod
  260. def is_matching_cls(cls):
  261. try:
  262. from transformers.file_utils import ModelOutput
  263. return issubclass(cls, ModelOutput)
  264. except ImportError:
  265. return False
  266. @classmethod
  267. def is_matching_object(cls, obj):
  268. return cls.is_matching_cls(type(obj))
  269. @classmethod
  270. def create(cls, user_cls, args, kwargs, options):
  271. DataClassVariable._patch_once()
  272. skip_code(user_cls.__init__.__code__)
  273. keys = [f.name for f in dataclasses.fields(user_cls)]
  274. bound = inspect.signature(user_cls).bind(*args, **kwargs)
  275. bound.apply_defaults()
  276. assert set(bound.arguments.keys()) == set(keys)
  277. items = collections.OrderedDict()
  278. for key in keys:
  279. val = bound.arguments[key]
  280. if isinstance(val, VariableTracker):
  281. items[key] = val
  282. else:
  283. if cls.include_none:
  284. assert variables.ConstantVariable.is_literal(val)
  285. items[key] = variables.ConstantVariable(val)
  286. else:
  287. assert val is None, f"unexpected {val}"
  288. if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable):
  289. unimplemented("DataClassVariable iterator constructor")
  290. # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__
  291. return cls(items, user_cls, **options)
  292. @classmethod
  293. def wrap(cls, builder, obj):
  294. user_cls = type(obj)
  295. keys = [f.name for f in dataclasses.fields(user_cls)]
  296. excluded = []
  297. items = collections.OrderedDict()
  298. for key in keys:
  299. # __init__ function of a dataclass might not have yet defined the key
  300. if hasattr(obj, key):
  301. val = getattr(obj, key)
  302. var = builder.__class__(
  303. tx=builder.tx, source=AttrSource(builder.source, key)
  304. )(val)
  305. if val is not None or cls.include_none:
  306. items[key] = var
  307. else:
  308. excluded.append(var)
  309. return cls(
  310. items, user_cls, **VariableTracker.propagate(excluded, items.values())
  311. )
  312. def __init__(self, items, user_cls, **options):
  313. super().__init__(items, user_cls, **options)
  314. assert self.is_matching_cls(user_cls)
  315. def as_proxy(self):
  316. raise NotImplementedError()
  317. def reconstruct(self, codegen):
  318. codegen.extend_output([codegen._create_load_const(self.user_cls)])
  319. keys = tuple(self.items.keys())
  320. for key in keys:
  321. codegen(self.items[key])
  322. return [
  323. codegen.create_load_const(keys),
  324. create_instruction("CALL_FUNCTION_KW", len(keys)),
  325. ]
  326. def call_method(
  327. self,
  328. tx,
  329. name,
  330. args: "List[VariableTracker]",
  331. kwargs: "Dict[str, VariableTracker]",
  332. ) -> "VariableTracker":
  333. options = VariableTracker.propagate(self, args, kwargs.values())
  334. if name == "__getitem__":
  335. assert not kwargs and len(args) == 1
  336. index = args[0].as_python_constant()
  337. if isinstance(index, str):
  338. return self.items[index].add_options(options)
  339. else:
  340. return (
  341. self.call_method(tx, "to_tuple", [], {})
  342. .call_method(tx, "__getitem__", args, kwargs)
  343. .add_options(options)
  344. )
  345. elif name == "to_tuple":
  346. assert not (args or kwargs)
  347. return variables.TupleVariable(list(self.items.values()), **options)
  348. elif name == "__setattr__":
  349. name = "__setitem__"
  350. return super().call_method(tx, name, args, kwargs)
  351. def var_getattr(self, tx, name: str) -> "VariableTracker":
  352. if name in self.items:
  353. return self.call_method(
  354. tx, "__getitem__", [variables.ConstantVariable(name)], {}
  355. )
  356. elif not self.include_none:
  357. defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
  358. if name in defaults:
  359. assert variables.ConstantVariable.is_literal(defaults[name])
  360. return variables.ConstantVariable(defaults[name]).add_options(self)
  361. super().var_getattr(tx, name)
  362. class HFPretrainedConfigVariable(VariableTracker):
  363. """
  364. Hack for HuggingFace PretrainedConfig
  365. """
  366. @staticmethod
  367. def is_matching_cls(cls):
  368. try:
  369. from transformers.configuration_utils import PretrainedConfig
  370. return issubclass(cls, PretrainedConfig)
  371. except ImportError:
  372. return False
  373. @classmethod
  374. def is_matching_object(cls, obj):
  375. return cls.is_matching_cls(type(obj))
  376. def __init__(self, obj, **kwargs):
  377. super().__init__(**kwargs)
  378. self.obj = obj
  379. assert self.is_matching_cls(type(obj))
  380. def var_getattr(self, tx, name: str) -> "VariableTracker":
  381. from . import ConstantVariable
  382. return ConstantVariable(getattr(self.obj, name))