# side_effects.py
  1. import collections
  2. import dataclasses
  3. import inspect
  4. from typing import Any, Dict, List, Optional
  5. import torch.nn
  6. from . import utils, variables
  7. from .bytecode_transformation import create_instruction
  8. from .codegen import PyCodegen
  9. from .source import LocalSource, Source
  10. from .utils import object_new
  11. from .variables.base import VariableTracker
  12. @dataclasses.dataclass
  13. class MutableSideEffects:
  14. """
  15. VariableTracker.mutable_local marker to indicate a list passed as
  16. an input that if we mutate we need to re-apply those mutations after
  17. the graph runs.
  18. """
  19. source: Source
  20. is_modified: bool = False
  21. def __hash__(self):
  22. return id(self)
  23. def __eq__(self, other):
  24. return self is other
  25. @dataclasses.dataclass
  26. class AttributeMutation:
  27. """
  28. VariableTracker.mutable_local marker to track changes to attributes
  29. """
  30. source: Source
  31. class AttributeMutationExisting(AttributeMutation):
  32. def __hash__(self):
  33. return id(self)
  34. def __eq__(self, other):
  35. return self is other
  36. @dataclasses.dataclass
  37. class AttributeMutationNew(AttributeMutation):
  38. cls_source: Source
  39. def __hash__(self):
  40. return id(self)
  41. def __eq__(self, other):
  42. return self is other
  43. class SideEffects:
  44. """
  45. Track side effects (list mutation, setattr, etc) that need to be
  46. applied after an FX graph is run.
  47. """
  48. id_to_variable: Dict[int, VariableTracker]
  49. store_attr_mutations: Dict[AttributeMutation, Dict[str, VariableTracker]]
  50. keepalive: List[Any]
  51. def __init__(self, id_to_variable=None, store_attr_mutations=None, keepalive=None):
  52. super().__init__()
  53. self.id_to_variable = id_to_variable or collections.OrderedDict()
  54. self.store_attr_mutations = store_attr_mutations or collections.OrderedDict()
  55. self.keepalive = keepalive or []
  56. def __eq__(self, other: object) -> bool:
  57. assert isinstance(other, SideEffects)
  58. # NB: do NOT test keepalive
  59. return (
  60. self.id_to_variable == other.id_to_variable
  61. and self.store_attr_mutations == other.store_attr_mutations
  62. )
  63. def diff(self, other: "SideEffects") -> Optional[str]:
  64. if self.id_to_variable != other.id_to_variable:
  65. sk_itv = self.id_to_variable.keys()
  66. ok_itv = other.id_to_variable.keys()
  67. if sk_itv != ok_itv:
  68. return f"id_to_variable keys: {sk_itv} != {ok_itv}"
  69. # Feel free to augment this with more fancy diffing logic
  70. # if needed for debugging
  71. return "id_to_variable: unknown diff"
  72. elif self.store_attr_mutations != other.store_attr_mutations:
  73. sk_sam = self.store_attr_mutations.keys()
  74. ok_sam = other.store_attr_mutations.keys()
  75. if sk_sam != ok_sam:
  76. return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
  77. return "store_attr_mutations: unknown diff"
  78. else:
  79. return None
  80. def clone(self):
  81. """Create a shallow copy"""
  82. return self.__class__(
  83. id_to_variable=collections.OrderedDict(self.id_to_variable),
  84. store_attr_mutations=collections.OrderedDict(
  85. (k, collections.OrderedDict(v))
  86. for k, v in self.store_attr_mutations.items()
  87. ),
  88. keepalive=list(self.keepalive),
  89. )
  90. def apply(self, fn, cache=None, skip_fn=lambda _: False):
  91. if cache is None:
  92. cache = dict()
  93. self.id_to_variable = collections.OrderedDict(
  94. (k, VariableTracker.apply(fn, v, cache, skip_fn))
  95. for k, v in self.id_to_variable.items()
  96. )
  97. self.store_attr_mutations = collections.OrderedDict(
  98. (k, VariableTracker.apply(fn, v, cache, skip_fn))
  99. for k, v in self.store_attr_mutations.items()
  100. )
  101. def __contains__(self, item):
  102. return id(item) in self.id_to_variable
  103. def __getitem__(self, item):
  104. return self.id_to_variable[id(item)]
  105. def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
  106. assert self.is_attribute_mutation(item)
  107. if item.mutable_local not in self.store_attr_mutations:
  108. self.store_attr_mutations[item.mutable_local] = collections.OrderedDict()
  109. self.store_attr_mutations[item.mutable_local][name] = value
  110. def load_attr(self, item, name):
  111. assert self.is_attribute_mutation(item)
  112. return self.store_attr_mutations[item.mutable_local][name]
  113. def store_cell(self, cellvar, value):
  114. assert isinstance(cellvar, variables.NewCellVariable)
  115. assert isinstance(value, variables.VariableTracker)
  116. self.store_attr(cellvar, "cell_contents", value)
  117. def load_cell(self, cellvar):
  118. assert isinstance(cellvar, variables.NewCellVariable)
  119. return self.load_attr(cellvar, "cell_contents")
  120. def load_global(self, gvar: VariableTracker, name: str):
  121. assert isinstance(gvar, variables.VariableTracker)
  122. return self.load_attr(gvar, name)
  123. def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
  124. assert isinstance(gvar, variables.VariableTracker)
  125. assert isinstance(value, variables.VariableTracker)
  126. self.store_attr(gvar, name, value)
  127. @staticmethod
  128. def cls_supports_mutation_side_effects(cls):
  129. return inspect.getattr_static(cls, "__setattr__", None) in (
  130. object.__setattr__,
  131. torch.nn.Module.__setattr__,
  132. )
  133. def is_attribute_mutation(self, item):
  134. return isinstance(item.mutable_local, AttributeMutation)
  135. def is_modified(self, item):
  136. if isinstance(item.mutable_local, AttributeMutationNew):
  137. return True
  138. if self.is_attribute_mutation(item):
  139. return item.mutable_local in self.store_attr_mutations
  140. return item.mutable_local.is_modified
  141. def _track_obj(
  142. self,
  143. source: Source,
  144. item: Any,
  145. variable: VariableTracker,
  146. mutable_cls=MutableSideEffects,
  147. ):
  148. """Start tracking a new variable for mutation"""
  149. variable = variable.clone(mutable_local=mutable_cls(source), source=source)
  150. self.id_to_variable[id(item)] = variable
  151. self.keepalive.append(item)
  152. return variable
  153. track_list = _track_obj
  154. track_dict = _track_obj
  155. def track_object_existing(
  156. self,
  157. source: Source,
  158. item: Any,
  159. variable: VariableTracker,
  160. ):
  161. return self._track_obj(
  162. source, item, variable, mutable_cls=AttributeMutationExisting
  163. )
  164. def track_object_new(
  165. self,
  166. cls_source: Source,
  167. user_cls: Any,
  168. variable_cls: Any,
  169. options,
  170. ):
  171. obj = object_new(user_cls)
  172. variable = variable_cls(
  173. obj,
  174. mutable_local=AttributeMutationNew(None, cls_source),
  175. **options,
  176. )
  177. self.id_to_variable[id(obj)] = variable
  178. self.keepalive.append(obj)
  179. return variable
  180. def track_cell_new(
  181. self,
  182. ):
  183. obj = object()
  184. variable = variables.NewCellVariable(
  185. mutable_local=AttributeMutationNew(None, None),
  186. )
  187. self.id_to_variable[id(obj)] = variable
  188. self.keepalive.append(obj)
  189. return variable
  190. def track_cell_existing(self, source: Source, item: Any):
  191. variable = variables.NewCellVariable(
  192. mutable_local=AttributeMutationExisting(source),
  193. )
  194. self.id_to_variable[id(item)] = variable
  195. self.keepalive.append(item)
  196. return variable
  197. def track_global_existing(self, source: Source, item: Any):
  198. variable = variables.NewGlobalVariable(
  199. mutable_local=AttributeMutationExisting(source),
  200. )
  201. self.id_to_variable[id(item)] = variable
  202. self.keepalive.append(item)
  203. return variable
  204. def prune_dead_object_new(self, tx):
  205. live_new_objects = set()
  206. skip_obj = None
  207. def visit(var: VariableTracker):
  208. if (
  209. isinstance(var.mutable_local, AttributeMutationNew)
  210. and var.mutable_local is not skip_obj
  211. ):
  212. live_new_objects.add(var.mutable_local)
  213. return var
  214. def is_live(var: VariableTracker):
  215. if isinstance(var, AttributeMutationNew):
  216. return var in live_new_objects
  217. if isinstance(var, VariableTracker):
  218. return is_live(var.mutable_local)
  219. return True
  220. VariableTracker.apply(visit, (tx.stack, tx.symbolic_locals))
  221. for var in self.id_to_variable.values():
  222. if not isinstance(var.mutable_local, AttributeMutationNew):
  223. VariableTracker.apply(visit, var)
  224. for skip_obj, setattrs in self.store_attr_mutations.items():
  225. VariableTracker.apply(visit, setattrs)
  226. self.id_to_variable = collections.OrderedDict(
  227. (k, v) for k, v in self.id_to_variable.items() if is_live(v)
  228. )
  229. self.store_attr_mutations = collections.OrderedDict(
  230. (k, v) for k, v in self.store_attr_mutations.items() if is_live(k)
  231. )
  232. def mutation(self, oldvar, newvar):
  233. return newvar.clone(
  234. mutable_local=MutableSideEffects(oldvar.mutable_local.source, True)
  235. )
  236. def _get_modified_vars(self):
  237. return [var for var in self.id_to_variable.values() if self.is_modified(var)]
  238. def codegen_save_tempvars(self, cg: PyCodegen):
  239. for var in self._get_modified_vars():
  240. if isinstance(
  241. var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
  242. ) and isinstance(var, variables.NewCellVariable):
  243. cg.load_import_from(utils.__name__, "make_cell")
  244. cg.extend_output([create_instruction("CALL_FUNCTION", 0)])
  245. cg.add_cache(var)
  246. if isinstance(var.mutable_local, AttributeMutationNew):
  247. var.mutable_local.source = LocalSource(cg.tempvars[var])
  248. elif isinstance(var.mutable_local, AttributeMutationNew):
  249. cg.load_import_from(utils.__name__, "object_new")
  250. cg(var.mutable_local.cls_source)
  251. cg.extend_output([create_instruction("CALL_FUNCTION", 1)])
  252. cg.add_cache(var)
  253. var.mutable_local.source = LocalSource(cg.tempvars[var])
  254. elif var in cg.tempvars:
  255. assert cg.tempvars.get(var) is None
  256. # subsequent usage should point to the original variable
  257. cg(var.mutable_local.source)
  258. cg.add_cache(var)
  259. def codegen_update_mutated(self, cg: PyCodegen):
  260. suffixes = []
  261. for var in self._get_modified_vars():
  262. if isinstance(var, variables.ListVariable):
  263. # old[:] = new
  264. cg(var, allow_cache=False)
  265. cg(var.mutable_local.source)
  266. cg.extend_output(
  267. [
  268. cg.create_load_const(None),
  269. cg.create_load_const(None),
  270. create_instruction("BUILD_SLICE", 2),
  271. ]
  272. )
  273. suffixes.append([create_instruction("STORE_SUBSCR")])
  274. elif isinstance(var, variables.ConstDictVariable):
  275. cg.tx.output.update_co_names("clear")
  276. cg.tx.output.update_co_names("update")
  277. cg(var.mutable_local.source)
  278. cg.extend_output([create_instruction("LOAD_METHOD", "update")])
  279. cg(var, allow_cache=False)
  280. cg(var.mutable_local.source)
  281. cg.extend_output([create_instruction("LOAD_METHOD", "clear")])
  282. suffixes.append(
  283. [
  284. create_instruction("CALL_METHOD", 0), # clear
  285. create_instruction("POP_TOP"),
  286. create_instruction("CALL_METHOD", 1), # update
  287. create_instruction("POP_TOP"),
  288. ]
  289. )
  290. elif self.is_attribute_mutation(var):
  291. for name, value in self.store_attr_mutations.get(
  292. var.mutable_local, {}
  293. ).items():
  294. if isinstance(var, variables.NewGlobalVariable):
  295. cg.tx.output.update_co_names(name)
  296. cg(value)
  297. suffixes.append([create_instruction("STORE_GLOBAL", name)])
  298. else:
  299. cg.tx.output.update_co_names(name)
  300. cg(value)
  301. cg(var.mutable_local.source)
  302. suffixes.append([create_instruction("STORE_ATTR", name)])
  303. else:
  304. raise AssertionError(type(var))
  305. # do all the actual mutations at the very end to handle dependencies
  306. for suffix in reversed(suffixes):
  307. cg.extend_output(suffix)
  308. def is_empty(self):
  309. return not any(map(self.is_modified, self.id_to_variable.values()))