import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib
import importlib.util  # find_spec is not guaranteed reachable via `import importlib` alone
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref

import torch
import torch._inductor.test_operators

try:
    import torch._prims

    # isort: split
    # TODO: Hack to unblock simultaneous landing changes. Fix after https://github.com/pytorch/pytorch/pull/81088 lands
    import torch._prims.utils
    import torch._prims.wrappers
    import torch._refs
    import torch._refs.nn
    import torch._refs.nn.functional
    import torch._refs.special

    HAS_PRIMS_REFS = True
except ImportError:
    HAS_PRIMS_REFS = False

from . import comptime, config, external_utils
- """
- A note on skipfiles:
- Dynamo consults this file to determine whether code should be compiled or skipped.
- A skip applies at the frame boundary, meaning dynamo either triggers a graph break
- at the beginning of the frame or attempts to trace the whole frame. When skipping
- a frame, recursively called frames are still traced by dynamo unless also skipped.
- Skipfiles (skipped at the file level instead of function level) still apply on a
- frame-by-frame boundary as dynamo traces, but apply to all functions in that file.
- @skip is a helper decorator that can be applied to your function to cause it to be
- included here.
- """

def _strip_init_py(s):
    # The dots are escaped so that only a literal "__init__.py" suffix is removed.
    return re.sub(r"__init__\.py$", "", s)


def _module_dir(m: types.ModuleType):
    return _strip_init_py(m.__file__)
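
# For example (paths are illustrative and depend on the install):
#     _module_dir(collections)  -> ".../lib/python3.x/collections/"
#     _module_dir(enum)         -> ".../lib/python3.x/enum.py"
# (enum is a plain module, so there is no "__init__.py" suffix to strip and
# the .py path itself becomes the skip prefix.)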

SKIP_DIRS = [
    # torch.*
    _module_dir(torch),
    # torchdynamo.*
    os.path.dirname(__file__) + "/",
    "<frozen importlib",
    "<__array_function__ internals>",
] + [
    # skip some standard libs
    _module_dir(m)
    for m in (
        abc,
        collections,
        contextlib,
        copy,
        copyreg,
        dataclasses,
        enum,
        functools,
        importlib,
        inspect,
        linecache,
        logging,
        multiprocessing,
        operator,
        os,
        posixpath,
        random,
        re,
        selectors,
        signal,
        tempfile,
        threading,
        tokenize,
        traceback,
        types,
        typing,
        unittest,
        weakref,
        _collections_abc,
        _weakrefset,
    )
]

FILENAME_ALLOWLIST = {
    torch.nn.Sequential.__init__.__code__.co_filename,
    torch.set_rng_state.__code__.co_filename,
    torch._inductor.test_operators.__file__,
    # These are dynamo files!
    external_utils.__file__,
    comptime.__file__,  # Want to inline these helpers
}

# Include optimizer code for tracing
FILENAME_ALLOWLIST |= {
    inspect.getfile(obj)
    for obj in torch.optim.__dict__.values()
    if inspect.isclass(obj)
}
FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}
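
# For example (illustrative; the exact set depends on the torch build),
# inspect.getfile(torch.optim.SGD) resolves to ".../torch/optim/sgd.py",
# so frames from that file are traced rather than skipped.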

if HAS_PRIMS_REFS:
    FILENAME_ALLOWLIST |= {
        torch._prims.__file__,
        torch._prims.utils.__file__,
        torch._prims.wrappers.__file__,
        torch._refs.__file__,
        torch._refs.special.__file__,
        torch._refs.nn.functional.__file__,
    }

SKIP_DIRS_RE = None


def _recompile_re():
    global SKIP_DIRS_RE
    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")
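
# The compiled pattern is an anchored alternation of escaped path prefixes, so
# a filename matches iff it starts with some SKIP_DIRS entry. For example
# (illustrative path):
#     SKIP_DIRS_RE.match("/usr/lib/python3.10/enum.py")  # truthy -> skipped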

def add(import_name: typing.Union[str, types.ModuleType]):
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    module_spec = importlib.util.find_spec(import_name)
    if not module_spec:
        # Module is not installed; nothing to skip.
        return
    origin = module_spec.origin
    if origin is None:
        # Namespace packages and builtins have no file origin to skip.
        return
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()
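
# Illustrative usage (a sketch; "scipy" may or may not be installed, in which
# case the call is a no-op):
#     add("scipy")        # skip every frame whose file lives under scipy/
#     add(collections)    # module objects are accepted and resolved by name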

def check(filename, allow_torch=False):
    """Should skip this file?"""
    if filename is None:
        return True
    if filename in FILENAME_ALLOWLIST:
        return False
    if allow_torch and is_torch(filename):
        return False
    return bool(SKIP_DIRS_RE.match(filename))
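
# Illustrative behavior (a sketch; exact results depend on the install):
#     check(torch.nn.functional.__file__)                    # True: under torch/
#     check(torch.nn.functional.__file__, allow_torch=True)  # False: is_torch()
#     check(external_utils.__file__)  # False: the allowlist wins over SKIP_DIRS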

# skip common third party libs
for _name in (
    "functorch",
    "intel_extension_for_pytorch",
    "networkx",
    "numpy",
    "omegaconf",
    "onnx",
    "onnxruntime",
    "onnx_tf",
    "pandas",
    "sklearn",
    "tabulate",
    "tensorflow",
    "tensorrt",
    "torch2trt",
    "tqdm",
    "tree",
    "tvm",
    "fx2trt_oss",
    "xarray",
):
    add(_name)

_recompile_re()

def is_torch_inline_allowed(filename):
    return any(
        filename.startswith(_module_dir(mod))
        for mod in config.skipfiles_inline_module_allowlist
    )
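
# Example (a sketch; the actual contents come from torch._dynamo.config and may
# differ): if config.skipfiles_inline_module_allowlist contains torch.nn, then
# files under ".../torch/nn/" are inlined by dynamo rather than skipped.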

@functools.lru_cache(None)
def dynamo_dir():
    # Imported lazily so this file does not import torch._dynamo at load time.
    import torch._dynamo

    return _module_dir(torch._dynamo)

def is_torch(filename):
    if filename.startswith(dynamo_dir()):
        return False
    return filename.startswith(_module_dir(torch))
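
# Illustrative results (a sketch; paths depend on the install):
#     is_torch(torch.nn.functional.__file__)     # True: lives under torch/
#     is_torch(dynamo_dir() + "eval_frame.py")   # False: dynamo's own files
#                                                # never count as "torch"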