skipfiles.py

import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib
import importlib.util  # needed below for importlib.util.find_spec
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref

import torch
import torch._inductor.test_operators

try:
    import torch._prims

    # isort: split
    # TODO: Hack to unblock simultaneous landing changes. Fix after https://github.com/pytorch/pytorch/pull/81088 lands
    import torch._prims.utils
    import torch._prims.wrappers
    import torch._refs
    import torch._refs.nn
    import torch._refs.nn.functional
    import torch._refs.special

    HAS_PRIMS_REFS = True
except ImportError:
    HAS_PRIMS_REFS = False

from . import comptime, config, external_utils

  47. """
  48. A note on skipfiles:
  49. Dynamo consults this file to determine whether code should be compiled or skipped.
  50. A skip applies at the frame boundary, meaning dynamo either triggers a graph break
  51. at the beginning of the frame or attempts to trace the whole frame. When skipping
  52. a frame, recursively called frames are still traced by dynamo unless also skipped.
  53. Skipfiles (skipped at the file level instead of function level) still apply on a
  54. frame-by-frame boundary as dynamo traces, but apply to all functions in that file.
  55. @skip is a helper decorator that can be applied to your function to cause it to be
  56. included here.
  57. """
def _strip_init_py(s):
    # Drop a trailing "__init__.py" so a package file path becomes its
    # directory path (the "." is escaped; the original pattern matched any char).
    return re.sub(r"__init__\.py$", "", s)


def _module_dir(m: types.ModuleType):
    return _strip_init_py(m.__file__)

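# For example (paths illustrative): _module_dir(collections) gives
# ".../lib/python3.x/collections/" because the trailing "__init__.py" is
# stripped, while _module_dir(abc) gives ".../lib/python3.x/abc.py" unchanged,
# since only package __init__ files are stripped; either form works as a
# filename prefix in the checks below.
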
SKIP_DIRS = [
    # torch.*
    _module_dir(torch),
    # torchdynamo.*
    os.path.dirname(__file__) + "/",
    "<frozen importlib",
    "<__array_function__ internals>",
] + [
    # skip some standard libs
    _module_dir(m)
    for m in (
        abc,
        collections,
        contextlib,
        copy,
        copyreg,
        dataclasses,
        enum,
        functools,
        importlib,
        inspect,
        linecache,
        logging,
        multiprocessing,
        operator,
        os,
        posixpath,
        random,
        re,
        selectors,
        signal,
        tempfile,
        threading,
        tokenize,
        traceback,
        types,
        typing,
        unittest,
        weakref,
        _collections_abc,
        _weakrefset,
    )
]

FILENAME_ALLOWLIST = {
    torch.nn.Sequential.__init__.__code__.co_filename,
    torch.set_rng_state.__code__.co_filename,
    torch._inductor.test_operators.__file__,
    # These are dynamo files!
    external_utils.__file__,
    comptime.__file__,  # Want to inline these helpers
}

# Include optimizer code for tracing
FILENAME_ALLOWLIST |= {
    inspect.getfile(obj)
    for obj in torch.optim.__dict__.values()
    if inspect.isclass(obj)
}
FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}

if HAS_PRIMS_REFS:
    FILENAME_ALLOWLIST |= {
        torch._prims.__file__,
        torch._prims.utils.__file__,
        torch._prims.wrappers.__file__,
        torch._refs.__file__,
        torch._refs.special.__file__,
        torch._refs.nn.functional.__file__,
    }

SKIP_DIRS_RE = None


def _recompile_re():
    global SKIP_DIRS_RE
    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")

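# The compiled pattern is an anchored alternation of the escaped SKIP_DIRS
# entries, so SKIP_DIRS_RE.match(filename) succeeds exactly when filename
# starts with one of them, e.g. (paths illustrative):
#     ^(.../torch/|.../python3\.x/abc\.py|<frozen importlib|...)
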
def add(import_name: str):
    """Add a module (by import name) to SKIP_DIRS and recompile the matcher."""
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    module_spec = importlib.util.find_spec(import_name)
    if not module_spec:
        return
    origin = module_spec.origin
    if origin is None:
        return
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()

def check(filename, allow_torch=False):
    """Should skip this file?"""
    if filename is None:
        return True
    if filename in FILENAME_ALLOWLIST:
        return False
    if allow_torch and is_torch(filename):
        return False
    return bool(SKIP_DIRS_RE.match(filename))

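# Usage sketch (illustrative; assumes module initialization has completed):
#     check(torch.nn.functional.__file__)                    # True: under torch/, skipped
#     check(torch.nn.functional.__file__, allow_torch=True)  # False: torch allowed here
#     check(torch.optim._functional.__file__)                # False: explicitly allowlisted
#     check("/home/user/model.py")                           # False: user code is traced
#     check(None)                                            # True
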
# skip common third party libs
for _name in (
    "functorch",
    "intel_extension_for_pytorch",
    "networkx",
    "numpy",
    "omegaconf",
    "onnx",
    "onnxruntime",
    "onnx_tf",
    "pandas",
    "sklearn",
    "tabulate",
    "tensorflow",
    "tensorrt",
    "torch2trt",
    "tqdm",
    "tree",
    "tvm",
    "fx2trt_oss",
    "xarray",
):
    add(_name)

# Ensure SKIP_DIRS_RE is compiled even if none of the libs above were found
# (add() only recompiles after a successful lookup).
_recompile_re()

def is_torch_inline_allowed(filename):
    return any(
        filename.startswith(_module_dir(mod))
        for mod in config.skipfiles_inline_module_allowlist
    )


@functools.lru_cache(None)
def dynamo_dir():
    import torch._dynamo

    return _module_dir(torch._dynamo)


def is_torch(filename):
    """Is `filename` part of torch proper, excluding dynamo's own directory?"""
    if filename.startswith(dynamo_dir()):
        return False
    return filename.startswith(_module_dir(torch))
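

if __name__ == "__main__":
    # Minimal sanity check, not part of the module's API (run with
    # `python -m torch._dynamo.skipfiles`, assuming that module path, since
    # the relative imports above prevent running this file directly):
    print(check(torch.__file__))                    # expected: True (torch is skipped)
    print(check(torch.__file__, allow_torch=True))  # expected: False
    print(check(__file__))                          # expected: True (dynamo file)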