guards.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. import builtins
  2. import collections
  3. import logging
  4. import math
  5. import os
  6. import re
  7. import types
  8. import weakref
  9. from inspect import currentframe, getframeinfo
  10. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  11. from weakref import ReferenceType
  12. import torch
  13. from torch._guards import (
  14. DuplicateInputs,
  15. Guard,
  16. GuardBuilderBase,
  17. GuardEnvExpr,
  18. GuardSource,
  19. Source,
  20. )
  21. from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP
  22. from . import config, convert_frame, mutation_guard
  23. from .eval_frame import set_guard_error_hook, set_guard_fail_hook
  24. from .exc import unimplemented
  25. from .types import GuardedCode, GuardFail, GuardFn # noqa: F401
  26. from .utils import (
  27. dict_const_keys,
  28. dict_const_keys_repr,
  29. dict_param_key_ids,
  30. guard_failures,
  31. HAS_NUMPY,
  32. istype,
  33. np,
  34. orig_code_map,
  35. rename_implicit,
  36. tuple_iterator_getitem,
  37. tuple_iterator_len,
  38. )
  log = logging.getLogger(__name__)

  # Short aliases for the C++ guard helpers exposed by torch._C._dynamo.guards.
  TensorGuards, check_obj_id, check_type_id = (
      torch._C._dynamo.guards.TensorGuards,
      torch._C._dynamo.guards.check_obj_id,
      torch._C._dynamo.guards.check_type_id,
  )

  # Helper names made available to generated guard code. GuardBuilder emits
  # expression strings that call these by name, e.g.
  #   ___check_type_id(x, y)  -> same as `id(type(x)) == y`
  #   ___check_obj_id(x, y)   -> same as `id(x) == y`
  CLOSURE_VARS = collections.OrderedDict(
      {
          "___check_type_id": check_type_id,
          "___check_obj_id": check_obj_id,
          "___is_grad_enabled": torch.is_grad_enabled,
          "___odict_getitem": collections.OrderedDict.__getitem__,
          "___dict_param_key_ids": dict_param_key_ids,
          "___dict_const_keys": dict_const_keys,
          "___tuple_iterator_len": tuple_iterator_len,
          "___tuple_iterator_getitem": tuple_iterator_getitem,
          "__math_isnan": math.isnan,
          "inf": float("inf"),
      }
  )
  def strip_function_call(name):
      """
      Peel guard-helper calls off a source expression to find its base name.

      "___odict_getitem(a, 1)" => "a"

      While the leftmost ``func(arg, ...)`` call (other than ``slice(...)``)
      matches, replace ``name`` with that call's first argument; finally drop
      everything from the first "." or "[" onward.
      """
      while True:
          call = re.search(r"([a-z0-9_]+)\(([^(),]+)[^()]*\)", name)
          if call is None or call.group(1) == "slice":
              break
          # The first argument is strictly shorter than `name`, so this ends.
          name = call.group(2)
      # Keep only the text before the first attribute access or subscript.
      return re.split(r"[.\[]", name)[0]
  def strip_getattr_getitem(name):
      """
      Return the part of ``name`` before the first attribute access or subscript.

      "a[1]" => "a"
      "a.foo" => "a"
      """
      # The pattern can match the empty string, so re.match always succeeds.
      prefix = re.match(r"[^.\[]*", name)
      return prefix.group(0)
  class GuardBuilder(GuardBuilderBase):
      """Translate dynamo ``Guard`` objects into Python guard-expression strings.

      Every guard kind (``TYPE_MATCH``, ``ID_MATCH``, ``EQUALS_MATCH``, ...) is a
      method of this class. Ordinary guards append expression strings to
      ``self.code``; ShapeEnv guards append to ``self.shape_env_code``; tensors
      are batched into ``tensor_check_names``/``tensor_check_examples`` so they
      can all be checked by a single C++ ``TensorGuards`` call.
      ``CheckFunctionManager`` joins the collected pieces into one guard
      function.
      """

      def __init__(
          self,
          id_ref: Callable[[Any], int],
          source_ref: Callable[[Source], str],
          scope: Optional[Dict[str, object]],
          check_fn_manager: "CheckFunctionManager",
          renames=True,
      ):
          """
          Args:
              id_ref: callback that records an object with the manager and
                  returns its ``id()`` for embedding in guard code.
              source_ref: callback mapping a ``Source`` to the expression
                  string that names it in guard code.
              scope: names guard expressions are evaluated against; ``None``
                  or empty means start from a fresh dict.
              check_fn_manager: the ``CheckFunctionManager`` owning this builder.
              renames: when True, a non-empty ``scope`` is copied with its keys
                  passed through ``rename_implicit``; when False it is used
                  as-is.
          """
          self.id_ref = id_ref
          self.source_ref = source_ref
          if scope:
              if renames:
                  scope = {rename_implicit(k): v for k, v in scope.items()}
          else:
              scope = dict()
          # NOTE(review): with renames=False and a non-empty scope, this is the
          # caller's dict itself, so the writes below mutate it. The compiled
          # guard is exec'd against the global builder's scope, so copying here
          # would presumably change which global values guards observe.
          self.scope: Dict[str, object] = scope
          # A private copy of builtins, so the torch.package names injected
          # below do not leak into the real `builtins` module.
          self.scope["__builtins__"] = builtins.__dict__.copy()
          for (
              name,
              package_module,
          ) in torch.package.package_importer._package_imported_modules.items():
              name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
              # Write the package module into the scope so that we can import it
              self.scope["__builtins__"][name] = package_module  # type: ignore[index]
              # Write the demangled name to the scope so that we can use it
              self.scope[name] = package_module

          # Base variable names referenced by guard code (see arg_ref). For the
          # local builder these become the guard lambda's parameters.
          self.argnames: List[str] = []
          # Code is python expression strings generated for each guard
          self.code: List[str] = []
          # shape_env_code is only used by local_builder and is used for
          # shape env code. This exists only because we need to make sure
          # shape env guards get run after tensor match guards (since the
          # tensor match guards make sure we actually have tensors)
          self.shape_env_code: List[str] = []

          # Most of the time, we generate Python code in a guard to directly
          # check various properties. However, tensors are a bit special;
          # it is too slow to check their properties one-by-one in Python.
          # Instead, there is a C++ function TensorGuards.check which takes
          # all of the tensor arguments and checks them all against compile-time
          # examples entirely in C++. Thus, every time we process a
          # TENSOR_MATCH guard, we just add another entry to
          # tensor_check_names/tensor_check_examples, saying "for this local,
          # check it against this example", and it all ends up getting
          # swept up into a single call to ___check_tensors. Invariant:
          # len(tensor_check_names) == len(tensor_check_examples).
          self.tensor_check_names: List[str] = []
          self.tensor_check_examples: List[torch.Tensor] = []
          self.tensor_check_ids: Dict[str, int] = {}

          self.check_fn_manager: CheckFunctionManager = check_fn_manager

      # Warning: use this with care! This lets you access what the current
      # value of the value you are guarding on is. You probably don't want
      # to actually durably save this value though (because it's specific
      # to this frame!) Instead, you should be reading out some property
      # (like its type) which is what you permanently install into the
      # guard code.
      def get(self, name: str) -> Any:
          # NOTE: `name` goes through eval(); it is expected to be a
          # dynamo-generated source expression, not untrusted input.
          return eval(name, self.scope, CLOSURE_VARS)

      # Registers the usage of the source name referenced by the
      # string (or stored in the Guard) as being guarded upon. It's important
      # to call this before generating some code that makes use of 'guard',
      # because without this call, we won't actually bind the variable
      # you reference in the actual guard closure (oops!)
      def arg_ref(self, guard: Union[str, Guard]) -> str:
          name: str = guard if isinstance(guard, str) else guard.name
          base = strip_getattr_getitem(strip_function_call(name))
          if base not in self.argnames:
              if re.match(r"^\d+$", base):
                  log.warning("invalid var name: %s", guard)
              self.argnames.append(base)
          return name

      def TYPE_MATCH(self, guard: Guard):
          # ___check_type_id is same as `id(type(x)) == y`
          t = type(self.get(guard.name))
          obj_id = self.id_ref(t)
          code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
          self._produce_guard_code(guard, [code])

      def ID_MATCH(self, guard: Guard):
          # ___check_obj_id is same as `id(x) == y`
          m = re.match(r"^type\((.+)\)$", guard.name)
          if m:
              # optional optimization to produce cleaner/faster guard code
              return self.TYPE_MATCH(
                  Guard(m.group(1), guard.source, GuardBuilder.TYPE_MATCH)
              )

          code = f"___check_obj_id({self.arg_ref(guard)}, {self.id_ref(self.get(guard.name))})"
          self._produce_guard_code(guard, [code])

      def NAME_MATCH(self, guard: Guard):
          obj = self.get(guard.name)
          # repr() turns the name into a string literal in the guard code;
          # interpolating it bare would emit an identifier reference instead.
          code = f"{self.arg_ref(guard)}.__name__ == {obj.__name__!r}"
          self._produce_guard_code(guard, [code])

      def HASATTR(self, guard: Guard):
          m = re.match(r"^(.*)[.]([a-zA-Z0-9_]+)$", guard.name)
          assert m, f"invalid hasattr check {guard.name}"
          base, attr = m.group(1, 2)
          ref = self.arg_ref(base)
          base_obj = self.get(base)
          if hasattr(base_obj, attr):
              code = f"hasattr({ref}, {attr!r})"
          else:
              code = f"not hasattr({ref}, {attr!r})"

          self._produce_guard_code(guard, [code], provided_guarded_object=base_obj)

      def EQUALS_MATCH(self, guard: Guard):
          ref = self.arg_ref(guard)
          val = self.get(guard.name)
          t = type(val)
          np_types = (
              (
                  np.int8,
                  np.int16,
                  np.int32,
                  np.int64,
                  np.uint8,
                  np.uint16,
                  np.uint32,
                  np.uint64,
                  np.float16,
                  np.float32,
                  np.float64,
              )
              if HAS_NUMPY
              else ()
          )
          assert istype(
              val,
              (
                  int,
                  float,
                  bool,
                  type(None),
                  str,
                  type,
                  list,
                  tuple,
                  set,
                  slice,
                  frozenset,
                  range,
                  torch.Size,
                  torch.device,
                  torch.dtype,
              )
              + np_types,
          ), t.__name__

          if istype(val, (torch.device, torch.dtype)):
              # TODO(jansel): is this slow? perhaps optimize it
              code = [f"str({ref}) == {str(val)!r}"]
              self._produce_guard_code(guard, code)
              return

          # Special case for nan because float("nan") == float("nan") evaluates to False
          if istype(val, float) and math.isnan(val):
              code = list()
              code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
              code.append(f"__math_isnan({ref})")
              self._produce_guard_code(guard, code)
              return

          # Add type check to prevent equality check between tensor and non-tensor.
          code = list()
          if istype(val, (list, tuple)):
              # LIST_LENGTH installs the container type check and the length check.
              self.LIST_LENGTH(guard)

              for idx, elem in enumerate(val):
                  code.append(
                      f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})"
                  )
          elif not istype(val, torch.Size):
              code.append(f"___check_type_id({ref}, {self.id_ref(t)})")

          if istype(val, torch.Size):
              # torch.Size is compared as a plain tuple.
              val = tuple(val)

          code.append(f"{ref} == {val!r}")
          self._produce_guard_code(guard, code)

      def CONSTANT_MATCH(self, guard: Guard):
          val = self.get(guard.name)
          # bool/None are singletons, so an identity check suffices.
          if istype(val, (bool, type(None))):
              self.ID_MATCH(guard)
          else:
              self.EQUALS_MATCH(guard)

      def NN_MODULE(self, guard: Guard):
          self.ID_MATCH(guard)
          ref = self.arg_ref(guard)
          val = self.get(guard.name)
          if hasattr(val, "training"):
              assert istype(val.training, bool)
              # Appended directly (not via _produce_guard_code); export info for
              # this guard comes from the ID_MATCH call above.
              self.code.append(f"{ref}.training == {val.training}")
          else:
              # There are cases where a monkeypatched object has a guard made between __new__ and __init__
              unimplemented(f"Guard setup for uninitialized class {type(val)}")

      def FUNCTION_MATCH(self, guard: Guard):
          """things like torch.add and user defined functions"""
          # Only locals get an identity guard; other sources get no code here.
          if guard.is_local():
              return self.ID_MATCH(guard)

      def BUILTIN_MATCH(self, guard: Guard):
          return self.FUNCTION_MATCH(guard)

      def PYMODULE_MATCH(self, guard: Guard):
          return self.FUNCTION_MATCH(guard)

      def LIST_LENGTH(self, guard):
          ref = self.arg_ref(guard)
          value = self.get(guard.name)
          t = type(value)

          code = list()
          code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
          code.append(f"len({ref}) == {len(value)}")

          self._produce_guard_code(guard, code)

      def TUPLE_ITERATOR_LEN(self, guard):
          ref = self.arg_ref(guard)
          value = self.get(guard.name)
          t = type(value)

          code = list()
          code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
          code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}")

          self._produce_guard_code(guard, code)

      def DICT_KEYS(self, guard):
          ref = self.arg_ref(guard)
          value = self.get(guard.name)
          t = type(value)

          code = list()
          code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
          param_key_ids = set(dict_param_key_ids(value))
          const_keys = set(dict_const_keys(value))
          const_keys_repr = dict_const_keys_repr(const_keys)
          if param_key_ids:
              # Keys include parameters: check them by id and the rest by value.
              code.append(f"___dict_param_key_ids({ref}) == {param_key_ids!r}")
              code.append(f"___dict_const_keys({ref}) == {const_keys_repr}")
          else:
              code.append(f"set({ref}.keys()) == {const_keys_repr}")

          self._produce_guard_code(guard, code)

      def WEAKREF_ALIVE(self, guard):
          self._produce_guard_code(guard, [f"{self.arg_ref(guard)} is not None"])

      def NN_MODULE_PARAM_NAMES(self, guard):
          ref = self.arg_ref(guard)
          value = self.get(guard.name)
          t = type(value)
          keys = {k for k, v in value.named_parameters()}

          code = list()
          code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
          code.append(f"{{k for k, v in {ref}.named_parameters()}} == {keys!r}")

          self._produce_guard_code(guard, code)

      def ODICT_KEYS(self, guard):
          """OrderedDict keys match"""
          ref = self.arg_ref(guard)
          value = self.get(guard.name)
          t = type(value)

          code = list()
          code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
          # Comparing str(keys()) also pins the key order.
          code.append(f"str({ref}.keys()) == {str(value.keys())!r}")

          self._produce_guard_code(guard, code)

      def OBJECT_MUTATION(self, guard: Guard):
          mutation_guard.watch(self.get(guard.name), self.check_fn_manager)

      def GRAD_MODE(self, guard: Guard):
          """Guard on the initial grad state"""
          assert guard.name == ""
          assert guard.source is GuardSource.GLOBAL
          code = (
              "___is_grad_enabled()"
              if convert_frame.initial_grad_state
              else "not ___is_grad_enabled()"
          )
          self._produce_guard_code(guard, [code])

      def SHAPE_ENV(self, guard: Guard):
          # Let's handle ShapeEnv guards. To do this, we will resolve
          # shape variables to sources from tracked_fakes. This must happen after
          # tensor checks.
          assert guard.name == ""
          output_graph = self.check_fn_manager.output_graph
          # NB: self.output_graph can be None in the debug_nops tests
          fs = output_graph.tracked_fakes
          guards = output_graph.shape_env.produce_guards(
              [a.fake for a in fs],
              [a.source for a in fs],
              source_ref=self.source_ref,
          )
          for shape_guard in guards:
              self._produce_guard_code(guard, [shape_guard], shape_env=True)

      def TENSOR_MATCH(self, guard: Guard):
          if guard.is_nn_module():
              self.ID_MATCH(guard)
          else:
              value = self.get(guard.name)
              assert isinstance(value, torch.Tensor)
              tensor_name = self.arg_ref(guard)
              self.tensor_check_names.append(tensor_name)
              self.tensor_check_examples.append(value)

              # STOP - DO NOT USE id_ref FOR TENSORS - TENSOR INVALIDATION RULES DIFFER
              self.tensor_check_ids[tensor_name] = id(value)

              # Note: Guard code produced for tensor_match is a little different.
              # We accumulate tensor names, then do a single install of `___check_tensors`.
              # See _guards.cpp and TensorGuard for more information.
              # TODO(voz): Add tensor matching code to export
              # Note: this is a bit of a special case, and so does not use _produce_guard_code
              guard.set_export_info(
                  "TENSOR_MATCH",
                  weakref.ref(type(value)),
                  None,
                  weakref.ref(value),
              )

      # A util that appends guarded code, or, in the case of export, adds data onto guards
      def _produce_guard_code(
          self, guard, code_list, provided_guarded_object=None, shape_env=False
      ):
          """Record guard expressions and attach export info to ``guard``.

          Args:
              guard: the ``Guard`` being installed.
              code_list: guard expression strings to append.
              provided_guarded_object: object to report in the export info;
                  when None it is re-evaluated from ``guard.name`` (if non-empty).
              shape_env: append to ``self.shape_env_code`` instead of
                  ``self.code``.

          Must be called directly from a ``GuardBuilder`` method: the caller's
          function name is read off the stack, asserted to be an attribute of
          this class, and recorded as the guard kind in the export info.
          """
          # WARNING: It is important that cur_frame/caller do NOT stay in
          # the current frame, because they will keep things live longer
          # than they should. See TestMisc.test_release_module_memory
          cur_frame = currentframe()
          assert cur_frame is not None
          caller = cur_frame.f_back
          del cur_frame
          assert caller is not None
          func_name = getframeinfo(caller)[2]
          del caller
          # We use func_name for export, so might as well get a nice defensive check out of it
          assert func_name in dir(
              self.__class__
          ), f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}"

          if shape_env:
              self.shape_env_code.extend(code_list)
          else:
              self.code.extend(code_list)

          # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD)
          if provided_guarded_object is None:
              name_valid = guard.name is not None and guard.name != ""

              guarded_object = self.get(guard.name) if name_valid else None
          else:
              guarded_object = provided_guarded_object

          guarded_object_type = (
              weakref.ref(type(guarded_object)) if guarded_object is not None else None
          )
          obj_ref = None
          # Only objects whose class supports weak references get an obj_ref.
          if hasattr(guarded_object.__class__, "__weakref__"):
              obj_ref = weakref.ref(guarded_object)

          guard.set_export_info(
              func_name,
              guarded_object_type,
              code_list,
              obj_ref,
          )
  412. # NB: Naively, you'd expect this to only be a function that produces
  413. # the callable that consistutes the guard. However, there is some
  414. # delicate handling for invalidating this check function when the
  415. # locals/globals get invalidated, so there's some extra state
  416. # we have to hold in this manager class.
  417. #
  418. # TODO: this object has reference cycle with itself, via check_fn which
  419. # references back to CheckFunction via ___guarded_code in closure_vars.
  420. # Ideally, there shouldn't be any ref cycle so that guards are
  421. # promptly disposed of.
  422. class CheckFunctionManager:
  423. def __init__(
  424. self,
  425. output_graph=None,
  426. f_locals: Optional[Dict[str, object]] = None,
  427. f_globals: Optional[Dict[str, object]] = None,
  428. guard_fail_fn: Optional[Callable[[Tuple[str, str]], None]] = None,
  429. ):
  430. guards = output_graph.guards if output_graph else None
  431. self.valid = True
  432. self._weakrefs: List["ReferenceType[object]"] = []
  433. self._seen_ids: Set[int] = set()
  434. self.output_graph = output_graph
  435. # Note: right overrides left
  436. def combine_scopes(left, right):
  437. if left is None:
  438. return right
  439. if right is None:
  440. return left
  441. return {**left, **right}
  442. def source_ref(source):
  443. guard_source = source.guard_source()
  444. if guard_source is GuardSource.CONSTANT:
  445. # No need to track constants
  446. return source.name()
  447. builder = guard_source.select(w_local(), w_global())
  448. assert builder is not None
  449. return builder.arg_ref(source.name())
  450. local_builder = GuardBuilder(
  451. self.id_ref,
  452. source_ref,
  453. combine_scopes(f_globals, f_locals),
  454. self,
  455. renames=True,
  456. )
  457. global_builder = GuardBuilder(
  458. self.id_ref, source_ref, f_globals, self, renames=False
  459. )
  460. # source_ref can cause a cycle, make sure we break it with weakref
  461. w_local = weakref.ref(local_builder)
  462. w_global = weakref.ref(global_builder)
  463. for guard in sorted(guards or [], key=Guard.sort_key):
  464. if (
  465. not config.guard_nn_modules
  466. and guard.is_nn_module()
  467. # Default func args must be guarded on.
  468. # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
  469. and "__defaults__" not in guard.name
  470. and "__kwdefaults__" not in guard.name
  471. ):
  472. continue
  473. guard.create(local_builder, global_builder)
  474. self.check_fn = self.compile_check_fn(
  475. local_builder, global_builder, guards, guard_fail_fn
  476. )
  477. self._seen_ids.clear()
  478. def compile_check_fn(
  479. self, local_builder, global_builder, guards_out, guard_fail_fn
  480. ):
  481. assert not (set(local_builder.argnames) & set(global_builder.argnames))
  482. # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
  483. largs = [a for a in local_builder.scope.keys() if a == "___implicit0"]
  484. largs += [a for a in local_builder.argnames if a != "___implicit0"]
  485. largs += ["**___kwargs_ignored"]
  486. args = ",".join(largs)
  487. code_parts = (
  488. ["___guarded_code.valid"] + local_builder.code + global_builder.code
  489. )
  490. # TODO(whc) maybe only the 'check_tensors' one is ambiguous? if so we can be less general..
  491. verbose_code_parts = (
  492. ["___guarded_code.valid"] + local_builder.code + global_builder.code
  493. )
  494. tensor_check_names = (
  495. local_builder.tensor_check_names + global_builder.tensor_check_names
  496. )
  497. tensor_check_ids = local_builder.tensor_check_ids.copy()
  498. tensor_check_ids.update(global_builder.tensor_check_ids)
  499. check_tensors_fn = None
  500. check_tensors_verbose_fn = None
  501. if tensor_check_names:
  502. tensor_check_examples = (
  503. local_builder.tensor_check_examples
  504. + global_builder.tensor_check_examples
  505. )
  506. tensor_guards = TensorGuards(
  507. *tensor_check_examples, dynamic_shapes=config.dynamic_shapes
  508. )
  509. check_tensors_fn = tensor_guards.check
  510. check_tensors_verbose_fn = tensor_guards.check_verbose
  511. code_parts.append(f"___check_tensors({', '.join(tensor_check_names)})")
  512. verbose_args = ", ".join(
  513. tensor_check_names + ["tensor_check_names=tensor_check_names"]
  514. )
  515. verbose_code_parts.append(f"___check_tensors_verbose({verbose_args})")
  516. aotautograd_guards: List[GuardEnvExpr] = (
  517. self.output_graph.tracing_context.guards_context.aotautograd_guards
  518. if self.output_graph
  519. else []
  520. )
  521. for guard in aotautograd_guards:
  522. if isinstance(guard, DuplicateInputs):
  523. pos_a = self.output_graph.pos_to_arg[guard.input_pos_a]
  524. pos_b = self.output_graph.pos_to_arg[guard.input_pos_b]
  525. assert (
  526. pos_b >= 0 and pos_a >= 0
  527. ), "Deduped args out of bounds, cannot be negative"
  528. assert self.output_graph.graphargs[
  529. pos_a
  530. ].is_tensor, "Deduped arg must be a tensor"
  531. assert self.output_graph.graphargs[
  532. pos_b
  533. ].is_tensor, "Deduped arg must be a tensor"
  534. code_part = f"{self.output_graph.graphargs[pos_a].source.name()} is {self.output_graph.graphargs[pos_b].source.name()}" # noqa: B950
  535. code_parts.append(code_part)
  536. verbose_code_parts.append(code_part)
  537. else:
  538. raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")
  539. code_parts.extend(local_builder.shape_env_code)
  540. verbose_code_parts.extend(local_builder.shape_env_code)
  541. assert not global_builder.shape_env_code
  542. code = " and ".join(unique(code_parts))
  543. closure_vars = collections.OrderedDict(
  544. [
  545. ("___guarded_code", self),
  546. ("___check_tensors", check_tensors_fn),
  547. ("___check_tensors_verbose", check_tensors_verbose_fn),
  548. ("tensor_check_names", tensor_check_names),
  549. ]
  550. + list(SYMPY_INTERP.items())
  551. )
  552. closure_vars.update(CLOSURE_VARS)
  553. py_code = f"""\
  554. def ___make_guard_fn({','.join(closure_vars.keys())}):
  555. return lambda {args}: {code}
  556. """
  557. if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1":
  558. print("GUARDS", code)
  559. set_guard_fail_hook(guard_fail_hook)
  560. out: Dict[str, Any] = dict()
  561. # print("RUNNING PY CODE", py_code)
  562. exec(py_code, global_builder.scope, out)
  563. guard_fn = out["___make_guard_fn"](*closure_vars.values())
  564. guard_fn.closure_vars = closure_vars
  565. # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
  566. guard_fn.args = largs
  567. guard_fn.code_parts = code_parts
  568. guard_fn.verbose_code_parts = verbose_code_parts
  569. guard_fn.global_scope = global_builder.scope
  570. guard_fn.guard_fail_fn = guard_fail_fn
  571. return guard_fn
  572. def invalidate(self, ref):
  573. # A weakref is no longer valid, self.check_fn should return false
  574. self.valid = False
  575. def id_ref(self, obj):
  576. """add a weakref, return the id"""
  577. try:
  578. if id(obj) not in self._seen_ids:
  579. self._weakrefs.append(weakref.ref(obj, self.invalidate))
  580. self._seen_ids.add(id(obj))
  581. except TypeError:
  582. pass # cannot weakref bool object
  583. return id(obj)
  584. def guard_fail_hook(
  585. guard_fn: GuardFn, code: types.CodeType, f_locals: Dict[str, object], last: bool
  586. ) -> None:
  587. """
  588. called whenever a guard fails.
  589. """
  590. if not guard_fn.guard_fail_fn and not last:
  591. return
  592. scope = {rename_implicit(k): v for k, v in f_locals.items()}
  593. scope.update(guard_fn.closure_vars)
  594. reason = None
  595. for part in guard_fn.verbose_code_parts:
  596. fail_reason = eval(part, guard_fn.global_scope, scope)
  597. # TODO(whc) hacky for now as not every 'part' in guard_fn.verbose_code_parts
  598. # is updated to return a string explaining the failure.
  599. if isinstance(fail_reason, str):
  600. reason = fail_reason
  601. break
  602. elif isinstance(fail_reason, bool) and not fail_reason:
  603. reason = part
  604. break
  605. try:
  606. if guard_fn.guard_fail_fn is not None:
  607. guard_fn.guard_fail_fn(
  608. GuardFail(reason or "unknown reason", orig_code_map[code])
  609. )
  610. except Exception as e:
  611. log.error(
  612. "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
  613. exc_info=True,
  614. )
  615. if last:
  616. guard_failures[orig_code_map[code]].append(reason)
  617. def guard_error_hook(
  618. guard_fn: GuardFn, code: types.CodeType, f_locals: Dict[str, object], last: bool
  619. ):
  620. print(
  621. f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
  622. )
  623. # TODO: If we passed in the exception here, we could get a precise
  624. # column number of which subexpression failed. But that would also
  625. # require us to have the TRUE code that was eval'ed, not a shoddy
  626. # reconstruction (like is done here)
  627. print("lambda " + ", ".join(guard_fn.args) + ":")
  628. print(" ", " and\n ".join(guard_fn.code_parts))
  629. set_guard_error_hook(guard_error_hook)
  630. def unique(seq):
  631. seen = set()
  632. for x in seq:
  633. if x not in seen:
  634. yield x
  635. seen.add(x)