allowed_functions.py

import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings
from typing import Dict, Optional, Set

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import HAS_NUMPY, is_safe_constant, np
"""
A note on allowed functions:

Dynamo consults this file to determine if a particular function/module
is allowed to appear as a node in its fx output.

If a function is disallowed, it may either be traced-through, or skipped.

Trace-through means dynamo will continue to trace the interior code for
the function/module rather than stopping at its boundary and recording it
as a node in the fx graph. Whether tracing through or allowing, the functionality
of the function/module is part of the dynamo graph. Caveat: if tracing through,
any interior operation could trigger its own graph-break.

Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note on
skipfiles" there.
"""
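
# Illustrative usage sketch (assumes the helpers defined at the bottom of this
# module):
#
#   >>> is_allowed(torch.add)      # allowed -> recorded as an fx graph node
#   True
#   >>> is_allowed(torch.nn.LSTM)  # torch.nn.modules.rnn is disallowed below
#   False
#
# Whether a disallowed object is then traced through or skipped entirely is
# decided by torch/_dynamo/skipfiles.py, as noted above.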


def make_function_id_set(lazy_initializer):
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph. Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.
    """

    class FunctionIdSet:
        function_ids: Optional[Set[int]] = None
        function_names: Optional[Dict[int, str]] = None

        def __call__(self):
            if self.function_ids is None:
                value = lazy_initializer()
                if isinstance(value, dict):
                    self.function_ids = set(value.keys())
                    self.function_names = value
                else:
                    assert isinstance(value, set)
                    self.function_ids = value
            return self.function_ids

        def get_name(self, idx: int, default: str):
            self()  # lazy init
            return self.function_names.get(idx, default)

        def add(self, idx: int):
            self()  # lazy init
            self.function_ids.add(idx)

        def remove(self, idx: int):
            if idx in self():
                self.function_ids.remove(idx)

        def __contains__(self, idx: int):
            return idx in self()

    return FunctionIdSet()
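
# Illustrative usage sketch: each set defined below is a lazily initialized
# FunctionIdSet, so callers may customize what dynamo puts in the graph by
# mutating it at runtime.  `my_helper` here is a hypothetical user function:
#
#   _allowed_function_ids.add(id(my_helper))     # allow it as a graph node
#   _allowed_function_ids.remove(id(torch.add))  # stop inlining it as a node
#                                                # (per the docstring above,
#                                                # this can cause a graph break)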


@make_function_id_set
def _disallowed_function_ids():
    remove = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.cuda.amp.autocast_mode.autocast,
        torch.cpu.amp.autocast_mode.autocast,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        torch.autograd.profiler.profile,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
    ]
    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage
    return {id(x) for x in remove}


@make_function_id_set
def _allowed_function_ids():
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters.  This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break.  To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = (
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
        )
        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__
        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    # torch.Tensor.{fn}
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(method, types.MethodDescriptorType):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids


@make_function_id_set
def _builtin_function_ids():
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv[id(functools.reduce)] = "functools.reduce"
    return rv


@make_function_id_set
def _numpy_function_ids():
    rv = dict()
    if HAS_NUMPY:
        for mod in (np, np.random):
            rv.update(
                {
                    id(v): f"{mod.__name__}.{k}"
                    for k, v in mod.__dict__.items()
                    if callable(v)
                    and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
                }
            )
    return rv


@make_function_id_set
def _builtin_constant_ids():
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv


def is_allowed(obj):
    """Is this safe to trace like torch.add ?"""
    # torch.ops is populated lazily so we don't necessarily have them in
    # _allowed_function_ids.  Figure it out by testing the type instead
    # in those cases
    return id(obj) in _allowed_function_ids or isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )
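
# Illustrative sketch: torch.ops entries are created lazily on first access,
# so they are matched by type rather than by id, e.g. (assuming the aten
# namespace has been populated):
#
#   >>> is_allowed(torch.ops.aten.add)         # OpOverloadPacket
#   True
#   >>> is_allowed(torch.ops.aten.add.Tensor)  # OpOverload
#   True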


def torch_get_name(obj, default):
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj):
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj):
    return id(obj) in _builtin_constant_ids


def is_numpy(obj):
    if HAS_NUMPY:
        return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids
    else:
        return False
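

# Minimal smoke test; a sketch that assumes this module is importable as part
# of torch._dynamo (e.g. via `python -m torch._dynamo.allowed_functions`),
# since the relative imports above prevent running it as a standalone script.
if __name__ == "__main__":
    # torch.add is collected by _allowed_function_ids, so it may appear as a
    # node in the fx graph and has a readable name.
    print(is_allowed(torch.add))           # expected: True
    print(torch_get_name(torch.add, "?"))  # expected: torch.add

    # Builtins are split into callables and constants.
    print(is_builtin_callable(len))        # expected: True
    print(is_builtin_constant(Ellipsis))   # expected: True

    # numpy coverage is optional and gated on HAS_NUMPY.
    if HAS_NUMPY:
        print(is_numpy(np.zeros(1)))       # expected: True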