import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings
from typing import Dict, Optional, Set

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import HAS_NUMPY, is_safe_constant, np

- """
- A note on allowed functions:
- Dynamo consults this file to determine if a particular function/module
- is allowed to appear as a node in its fx output.
- If a function is disallowed, it may either be traced-through, or skipped.
- Trace-through means dynamo will continue to trace the interior code for
- the function/module rather than stopping at its boundary and recording it
- as a node in the fx graph. Whether tracing through or allowing, the functionality
- of the function/module is part of the dynamo graph. Caveat: if tracing through,
- any interior operation could trigger its own graph-break.
- Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note on
- skipfiles" there.
- """

def make_function_id_set(lazy_initializer):
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph.  Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.
    """

    class FunctionIdSet:
        function_ids: Optional[Set[int]] = None
        function_names: Optional[Dict[int, str]] = None

        def __call__(self):
            if self.function_ids is None:
                value = lazy_initializer()
                if isinstance(value, dict):
                    self.function_ids = set(value.keys())
                    self.function_names = value
                else:
                    assert isinstance(value, set)
                    self.function_ids = value
            return self.function_ids

        def get_name(self, idx: int, default: str):
            self()  # lazy init
            return self.function_names.get(idx, default)

        def add(self, idx: int):
            self()  # lazy init
            self.function_ids.add(idx)

        def remove(self, idx: int):
            if idx in self():
                self.function_ids.remove(idx)

        def __contains__(self, idx: int):
            return idx in self()

    return FunctionIdSet()

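# Example usage (an illustrative sketch; `my_fn` is a hypothetical user
# function).  Mutating the FunctionIdSet instances defined below is the
# low-level mechanism that helpers such as torch._dynamo.allow_in_graph
# can build on:
#
#     >>> def my_fn(x):
#     ...     return x + 1
#     >>> _allowed_function_ids.add(id(my_fn))     # allow into the graph
#     >>> id(my_fn) in _allowed_function_ids
#     True
#     >>> _allowed_function_ids.remove(id(my_fn))  # revert to default handling
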

@make_function_id_set
def _disallowed_function_ids():
    remove = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.cuda.amp.autocast_mode.autocast,
        torch.cpu.amp.autocast_mode.autocast,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        torch.autograd.profiler.profile,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
    ]

    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes

    # extract all storage classes from torch
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage

    return {id(x) for x in remove}

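# The dtype/storage scans above rely on all dtype (and storage) objects
# sharing one concrete type; a quick sanity check (illustrative only):
#
#     >>> type(torch.float32) is torch.dtype
#     True
#     >>> isinstance(torch.int64, type(torch.float32))
#     True
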

@make_function_id_set
def _allowed_function_ids():
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters.  This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break.  To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = (
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
        )
        allowed_modules_dot = tuple(x + "." for x in allowed_modules)
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__
        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    # torch.Tensor.{fn}
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(method, types.MethodDescriptorType):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids

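# Besides membership tests, the dict built here doubles as a reverse name
# table (illustrative; the exact entries depend on the torch build):
#
#     >>> _allowed_function_ids.get_name(id(torch.nn.functional.relu), "<unknown>")
#     'torch.nn.functional.relu'
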

@make_function_id_set
def _builtin_function_ids():
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        # chain and islice live in itertools, so label them accordingly
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv[id(functools.reduce)] = "functools.reduce"
    return rv

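# What this table does and does not cover (illustrative):
#
#     >>> is_builtin_callable(len)           # builtins.len
#     True
#     >>> is_builtin_callable(operator.add)  # operator.add
#     True
#     >>> is_builtin_callable(torch.add)     # torch.* lives in _allowed_function_ids
#     False
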

@make_function_id_set
def _numpy_function_ids():
    rv = dict()
    if HAS_NUMPY:
        for mod in (np, np.random):
            rv.update(
                {
                    id(v): f"{mod.__name__}.{k}"
                    for k, v in mod.__dict__.items()
                    if callable(v)
                    and (getattr(v, "__module__", None) or mod.__name__)
                    == mod.__name__
                }
            )
    return rv


@make_function_id_set
def _builtin_constant_ids():
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv


def is_allowed(obj):
    """Is this safe to trace like torch.add?"""
    # torch.ops is populated lazily so we don't necessarily have them in
    # _allowed_function_ids.  Figure it out by testing the type instead
    # in those cases
    return id(obj) in _allowed_function_ids or isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )

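# The isinstance() fallback covers lazily-registered ops (illustrative):
#
#     >>> is_allowed(torch.ops.aten.add)         # OpOverloadPacket
#     True
#     >>> is_allowed(torch.ops.aten.add.Tensor)  # OpOverload
#     True
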

def torch_get_name(obj, default):
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj):
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj):
    return id(obj) in _builtin_constant_ids


def is_numpy(obj):
    if HAS_NUMPY:
        return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids
    else:
        return False

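
# A minimal sketch of how a caller might combine the predicates in this file
# (`classify` is hypothetical, not part of this module or of Dynamo's API):
#
#     def classify(obj):
#         if is_allowed(obj):
#             return "record as an fx graph node"
#         if is_builtin_callable(obj) or is_numpy(obj):
#             return "handle specially / constant-fold"
#         if is_builtin_constant(obj):
#             return "treat as a constant"
#         return "trace through, or graph-break (see skipfiles)"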