source.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import collections
  2. import dataclasses
  3. import enum
  4. from typing import Any, Optional, Union
  5. from torch._guards import GuardSource, Source
  6. from . import utils
  7. from .bytecode_transformation import create_instruction
  8. from .utils import enum_repr, rename_implicit
  # Lookup used by NNModuleSource.guard_source: maps a LOCAL/GLOBAL guard
  # source (or one already tagged as an nn.Module source) to the nn.Module
  # tagged variant.  Any other GuardSource is absent and triggers KeyError.
  _GUARD_SOURCE_NN_MODULE = dict(
      [
          (GuardSource.LOCAL, GuardSource.LOCAL_NN_MODULE),
          (GuardSource.GLOBAL, GuardSource.GLOBAL_NN_MODULE),
          (GuardSource.LOCAL_NN_MODULE, GuardSource.LOCAL_NN_MODULE),
          (GuardSource.GLOBAL_NN_MODULE, GuardSource.GLOBAL_NN_MODULE),
      ]
  )

  # Inverse-style lookup used by NotNNModuleSource.guard_source: strips the
  # nn.Module tag, collapsing back to plain LOCAL/GLOBAL.
  _GUARD_SOURCE_NOT_NN_MODULE = dict(
      [
          (GuardSource.LOCAL, GuardSource.LOCAL),
          (GuardSource.GLOBAL, GuardSource.GLOBAL),
          (GuardSource.LOCAL_NN_MODULE, GuardSource.LOCAL),
          (GuardSource.GLOBAL_NN_MODULE, GuardSource.GLOBAL),
      ]
  )
  def is_constant_source(source):
      """Return True if *source* denotes a constant.

      A source is constant when it is a ConstantSource instance, or when its
      ``guard_source()`` reports ``GuardSource.CONSTANT``.  A source whose
      ``guard_source()`` raises NotImplementedError is treated as non-constant.
      """
      if isinstance(source, ConstantSource):
          return True
      try:
          is_const = source.guard_source() == GuardSource.CONSTANT
      except NotImplementedError:
          is_const = False
      return bool(is_const)
  def is_input_source(source):
      """Return True if *source* is rooted in a local or a global.

      The nn.Module-tagged local/global guard sources also count as inputs.
      """
      kind = source.guard_source()
      input_kinds = (
          GuardSource.LOCAL,
          GuardSource.GLOBAL,
          GuardSource.LOCAL_NN_MODULE,
          GuardSource.GLOBAL_NN_MODULE,
      )
      return kind in input_kinds
  @dataclasses.dataclass
  class LocalSource(Source):
      """Source for a value loaded from a local variable by name."""

      # Name of the local variable to load.
      local_name: str

      def reconstruct(self, codegen):
          load_local = codegen.create_load(self.local_name)
          return [load_local]

      def guard_source(self):
          return GuardSource.LOCAL

      def name(self):
          # The guard-facing name is whatever rename_implicit (from .utils)
          # produces for the local's name.
          return rename_implicit(self.local_name)
  @dataclasses.dataclass
  class LocalInputSource(LocalSource):
      """A LocalSource that additionally records an integer position."""

      # Position index; presumably the input's argument position — confirm
      # against callers.
      pos: int
  @dataclasses.dataclass
  class RandomValueSource(Source):
      """Source for entry ``random_call_index`` of the output's random-values variable."""

      # Subscript into the random-values container.
      random_call_index: int

      def guard_source(self):
          return GuardSource.RANDOM_VALUE

      def reconstruct(self, codegen):
          # Emits: <random_values_var>[random_call_index]
          container = codegen.create_load(codegen.tx.output.random_values_var)
          position = codegen.create_load_const(self.random_call_index)
          subscript = create_instruction("BINARY_SUBSCR")
          return [container, position, subscript]

      def name(self):
          return rename_implicit(f"random_value_{self.random_call_index}")
  @dataclasses.dataclass
  class GlobalSource(Source):
      """Source for a value loaded by global name."""

      # Name of the global to load; also used verbatim as the guard name.
      global_name: str

      def reconstruct(self, codegen):
          load_global = codegen.create_load_global(self.global_name, add=True)
          return [load_global]

      def guard_source(self):
          return GuardSource.GLOBAL

      def name(self):
          return self.global_name
  @dataclasses.dataclass
  class GlobalWeakRefSource(Source):
      """Source for the result of calling a global with no arguments.

      Judging by the class name, the global presumably holds a weak reference
      and calling it yields the referent — confirm with callers.
      """

      global_name: str

      def reconstruct(self, codegen):
          # Emits: <global_name>()
          load_global = codegen.create_load_global(self.global_name, add=True)
          call_no_args = create_instruction("CALL_FUNCTION", 0)
          return [load_global, call_no_args]

      def guard_source(self):
          return GuardSource.GLOBAL

      def name(self):
          return f"{self.global_name}()"
  @dataclasses.dataclass
  class AttrSource(Source):
      """Source for the attribute ``member`` of ``base``.

      A dotted ``member`` such as ``"a.b.c"`` is normalized into a chain of
      nested AttrSources, AttrSource(AttrSource(AttrSource(base, "a"), "b"), "c"),
      so every instance holds exactly one attribute name.
      """

      base: Source
      member: str

      def __init__(self, base, member):
          super().__init__()
          assert base, "Can't construct an AttrSource without a valid base source"
          if "." in member:
              # Peel off the last component; the recursive constructor call
              # handles any remaining dots in the prefix.
              prefix, _, last = member.rpartition(".")
              self.base = AttrSource(base, prefix)
              self.member = last
          else:
              self.base = base
              self.member = member
          assert self.base is not None

      def reconstruct(self, codegen):
          return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)

      def guard_source(self):
          return self.base.guard_source()

      def name(self):
          """Return a Python expression string for this attribute access.

          ``base.member`` is only valid syntax when ``member`` is a non-keyword
          identifier.  Anything else (e.g. ``"0"``, ``"a-b"``, ``""`` or
          ``"class"``) is spelled ``getattr(base, 'member')`` instead.
          """
          import keyword

          if not self.member.isidentifier() or keyword.iskeyword(self.member):
              return f"getattr({self.base.name()}, {self.member!r})"
          return f"{self.base.name()}.{self.member}"
  class TensorProperty(enum.Enum):
      """Which piece of tensor metadata a TensorPropertySource reads."""

      # Values 0, 1, 2 in declaration order.
      SIZE, STRIDE, STORAGE_OFFSET = range(3)
  @dataclasses.dataclass
  class TensorPropertySource(Source):
      """Source for one piece of tensor metadata read off ``base``.

      ``idx`` selects the dimension for SIZE and STRIDE and must be None for
      STORAGE_OFFSET; this is checked in ``__post_init__``.
      """

      base: Source
      prop: TensorProperty
      idx: Optional[int] = None  # None for STORAGE_OFFSET

      def __post_init__(self):
          assert self.base is not None
          # STORAGE_OFFSET is the only property that takes no index.
          needs_idx = self.prop is not TensorProperty.STORAGE_OFFSET
          if needs_idx:
              assert self.idx is not None
          else:
              assert self.idx is None

      def reconstruct(self, codegen):
          raise NotImplementedError()

      def guard_source(self):
          return self.base.guard_source()

      def name(self):
          if self.prop is TensorProperty.STORAGE_OFFSET:
              assert self.idx is None
              return f"{self.base.name()}.storage_offset()"
          if self.prop is TensorProperty.SIZE:
              return f"{self.base.name()}.size()[{self.idx}]"
          if self.prop is TensorProperty.STRIDE:
              return f"{self.base.name()}.stride()[{self.idx}]"
          raise AssertionError(f"unhandled {self.prop}")
  @dataclasses.dataclass
  class NegateSource(Source):
      """Source for the negation of ``base``; named as ``base.__neg__()``."""

      base: Source

      def __post_init__(self):
          assert self.base is not None

      def reconstruct(self, codegen):
          # Reconstruction is not supported for negated values.
          raise NotImplementedError()

      def guard_source(self):
          return self.base.guard_source()

      def name(self):
          # NB: spelled as a method call rather than a unary minus so that the
          # function-stripping regexes applied to guard names keep working.
          return f"{self.base.name()}.__neg__()"
  @dataclasses.dataclass
  class DefaultsSource(Source):
      """Source for one default-argument value of the function at ``base``.

      With ``is_kw=False`` the value is ``base.__defaults__[idx_key]`` (int
      index).  With ``is_kw=True`` it is ``base.__kwdefaults__[idx_key]``
      (str key).  The guard name is computed once, in ``__init__``.
      """

      base: Source
      idx_key: Union[int, str]
      is_kw: bool
      field: str

      def __init__(self, base, idx_key, is_kw=False):
          super().__init__()
          assert (
              base
          ), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
          self.base = base
          self.idx_key = idx_key
          self.is_kw = is_kw
          if not self.is_kw:
              # Positional defaults: a tuple indexed by int.
              assert isinstance(idx_key, int)
              self.field = "__defaults__"
              self._name = f"{self.base.name()}.{self.field}[{self.idx_key}]"
          else:
              # Keyword-only defaults: a dict keyed by parameter name.
              assert isinstance(idx_key, str)
              self.field = "__kwdefaults__"
              self._name = f"{self.base.name()}.{self.field}['{self.idx_key}']"

      def reconstruct(self, codegen):
          # Emits: <base>.<field>[idx_key]
          instrs = self.base.reconstruct(codegen)
          instrs += codegen.create_load_attrs(self.field)
          instrs.append(codegen.create_load_const(self.idx_key))
          instrs.append(create_instruction("BINARY_SUBSCR"))
          return instrs

      def guard_source(self):
          return self.base.guard_source()

      def name(self):
          return self._name
  @dataclasses.dataclass
  class GetItemSource(Source):
      """Source for ``base[index]``.

      ``index`` is either another Source (reconstructed and named
      recursively) or a plain value loaded as a constant.
      """

      base: Source
      index: Any

      def __post_init__(self):
          assert self.base is not None

      def reconstruct(self, codegen):
          instrs = self.base.reconstruct(codegen)
          key = self.index
          if not isinstance(key, Source):
              instrs.append(codegen.create_load_const(key))
          else:
              instrs.extend(key.reconstruct(codegen))
          instrs.append(create_instruction("BINARY_SUBSCR"))
          return instrs

      def guard_source(self):
          return self.base.guard_source()

      def name(self):
          if isinstance(self.index, Source):
              return f"{self.base.name()}[{self.index.name()}]"
          if isinstance(self.index, enum.Enum):
              # The default repr of an Enum member (e.g. <Color.RED: 1>) is not
              # a valid expression, so enum keys go through enum_repr (.utils).
              return f"{self.base.name()}[{enum_repr(self.index)}]"
          return f"{self.base.name()}[{self.index!r}]"
  @dataclasses.dataclass
  class TupleIteratorGetItemSource(GetItemSource):
      """GetItemSource variant that reads ``index`` via ``utils.tuple_iterator_getitem``."""

      def reconstruct(self, codegen):
          # load_import_from is called for its effect on codegen; its return
          # value is discarded.  The returned instructions then emit
          # helper(base, index).
          codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
          base_instrs = self.base.reconstruct(codegen)
          return base_instrs + [
              codegen.create_load_const(self.index),
              create_instruction("CALL_FUNCTION", 2),
          ]

      def name(self):
          return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
  @dataclasses.dataclass
  class TypeSource(Source):
      """Source for ``type(base)``."""

      base: Source

      def __post_init__(self):
          assert self.base is not None

      def reconstruct(self, codegen):
          # Loads builtins.type via load_import_from (return value unused),
          # then emits the argument and a one-argument call.
          codegen.load_import_from("builtins", "type")
          base_instrs = self.base.reconstruct(codegen)
          return base_instrs + [create_instruction("CALL_FUNCTION", 1)]

      def guard_source(self):
          return self.base.guard_source()

      def name(self):
          return f"type({self.base.name()})"
  @dataclasses.dataclass
  class SuperSource(Source):
      """Source for ``super(type, obj)``; guarded through ``obj``."""

      type: Source
      obj: Source

      def __post_init__(self):
          assert self.type is not None
          assert self.obj is not None

      def reconstruct(self, codegen):
          codegen.load_import_from("builtins", "super")
          instrs = self.type.reconstruct(codegen)
          instrs = instrs + self.obj.reconstruct(codegen)
          instrs = instrs + [create_instruction("CALL_FUNCTION", 2)]
          return instrs

      def guard_source(self):
          # The guard source comes from obj, not from type.
          return self.obj.guard_source()

      def name(self):
          return f"super({self.type.name()}, {self.obj.name()})"
  @dataclasses.dataclass
  class ODictGetItemSource(Source):
      """Source for ``collections.OrderedDict.__getitem__(base, index)``."""

      base: Source
      index: Any

      def __post_init__(self):
          assert self.base is not None

      def reconstruct(self, codegen):
          # The unbound OrderedDict.__getitem__ is loaded through the private
          # _create_load_const.  Presumably the public create_load_const does
          # not accept a function object as a constant — confirm before
          # changing.
          getitem = codegen._create_load_const(collections.OrderedDict.__getitem__)
          instrs = [getitem] + self.base.reconstruct(codegen)
          return instrs + [
              codegen.create_load_const(self.index),
              create_instruction("CALL_FUNCTION", 2),
          ]

      def guard_source(self):
          return self.base.guard_source()

      def name(self):
          return f"___odict_getitem({self.base.name()}, {self.index!r})"
  @dataclasses.dataclass
  class NNModuleSource(Source):
      """Wrapper that re-tags ``inner``'s guard source as its nn.Module variant.

      Reconstruction and naming are delegated unchanged to ``inner``.
      """

      inner: Source

      def reconstruct(self, codegen):
          return self.inner.reconstruct(codegen)

      def guard_source(self):
          # KeyError if inner's guard source is not a LOCAL/GLOBAL flavour.
          inner_kind = self.inner.guard_source()
          return _GUARD_SOURCE_NN_MODULE[inner_kind]

      def name(self):
          return self.inner.name()
  class NotNNModuleSource(NNModuleSource):
      """Wrapper that strips the nn.Module tag from ``inner``'s guard source.

      Not separately decorated as a dataclass; the ``inner`` field and the
      generated dataclass methods are inherited from NNModuleSource.
      """

      def guard_source(self):
          inner_kind = self.inner.guard_source()
          return _GUARD_SOURCE_NOT_NN_MODULE[inner_kind]
  @dataclasses.dataclass
  class ConstantSource(Source):
      """Source for a named constant; loaded by global name without registering it."""

      # Name used both to load the value and as the guard name.
      source_name: str

      def reconstruct(self, codegen):
          load_const_global = codegen.create_load_global(self.source_name, add=False)
          return [load_const_global]

      def guard_source(self):
          return GuardSource.CONSTANT

      def name(self):
          return self.source_name

      def make_guard(self, fn, is_volatile=False):
          # Guards cannot be built for a ConstantSource.
          raise NotImplementedError()
  # Synthetic source associated with the single shape-env guard that is always
  # registered for every frame.  The guard's actual contents are taken from
  # the ambient ShapeEnv, not from this object.
  @dataclasses.dataclass
  class ShapeEnvSource(Source):
      """Placeholder source for the frame-wide ShapeEnv guard (empty name)."""

      def guard_source(self):
          return GuardSource.SHAPE_ENV

      def name(self):
          return ""