| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 | import builtinsimport collectionsimport copyimport functoolsimport inspectimport itertoolsimport mathimport operatorimport typesimport warningsfrom typing import Dict, Optional, Setimport torchfrom torch.fx._symbolic_trace import is_fx_tracingfrom . import configfrom .external_utils import is_compilingfrom .utils import HAS_NUMPY, is_safe_constant, np"""A note on allowed functions:Dynamo consults this file to determine if a particular function/moduleis allowed to appear as a node in its fx output.If a function is disallowed, it may either be traced-through, or skipped.Trace-through means dynamo will continue to trace the interior code forthe function/module rather than stopping at its boundary and recording itas a node in the fx graph. Whether tracing through or allowing, the functionalityof the function/module is part of the dynamo graph.  Caveat: if tracing through,any interior operation could trigger its own graph-break.Skips are determined by (torch/_dynamo/skipfiles.py) - see "a note onskipfiles" there."""def make_function_id_set(lazy_initializer):    """    Track a set of `id()`s of objects which are either allowed or not    allowed to go into the generated FX graph.  Use to test for torch.*,    numpy.*, builtins.*, etc.    Support user modification to permit customization of what can be    added to the graph and what will cause a graph break.    """    class FunctionIdSet:        function_ids: Optional[Set[int]] = None        function_names: Optional[Dict[int, str]] = None        def __call__(self):            if self.function_ids is None:                value = lazy_initializer()                if isinstance(value, dict):                    self.function_ids = set(value.keys())                    self.function_names = value                else:                    assert isinstance(value, set)                    self.function_ids = value            return self.function_ids        def get_name(self, idx: int, default: str):            self()  # lazy init            return self.function_names.get(idx, default)        def add(self, idx: int):            self()  # lazy init            self.function_ids.add(idx)        def remove(self, idx: int):            if idx in self():                self.function_ids.remove(idx)        def __contains__(self, idx: int):            return idx in self()    return FunctionIdSet()@make_function_id_setdef _disallowed_function_ids():    remove = [        True,        False,        None,        collections.OrderedDict,        copy.copy,        copy.deepcopy,        inspect.signature,        math.__package__,        torch.__builtins__,        torch.autocast_decrement_nesting,        torch.autocast_increment_nesting,        torch.autograd.grad,        torch.clear_autocast_cache,        torch.cuda.current_device,        torch.cuda.amp.autocast_mode.autocast,        torch.cpu.amp.autocast_mode.autocast,        torch.distributions.constraints.is_dependent,        torch.distributions.normal.Normal,        torch.inference_mode,        torch.set_anomaly_enabled,        torch.set_autocast_cache_enabled,        torch.set_autocast_cpu_dtype,        torch.set_autocast_cpu_enabled,        torch.set_autocast_enabled,        torch.set_autocast_gpu_dtype,        torch.autograd.profiler.profile,        warnings.warn,        torch._C._dynamo.eval_frame.unsupported,    ]    # extract all dtypes from torch    dtypes = [        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))    ]    remove += dtypes    storage = [        obj        for obj in torch.__dict__.values()        if isinstance(obj, type(torch.FloatStorage))    ]    remove += storage    return {id(x) for x in remove}@make_function_id_setdef _allowed_function_ids():    """    Walk torch.* and get the ids of all the stuff in it    """    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")    torch_object_ids = dict()    def _is_allowed_module_prefix(obj):        allowed_modules = ("torch", "math")        # torch.nn.modules.rnn is disallowed because these modules internally        # flatten their parameters.  This flattening process will call        # Tensor.set_ with a Storage, and Storages cannot be traced with        # AOTAutograd; so we need to graph-break. To ensure this, we inline        # these functions, rather than keep them opaque-ly in the graph.        disallowed_modules = (            "torch.optim.",            "torch.nn.modules.rnn.",            "torch._dynamo.",            "torch._C._dynamo.",            "torch._inductor.",            "torch._C.inductor.",            "torch.fx.",            "torch.distributed.fsdp.",        )        allowed_modules_dot = tuple([x + "." for x in allowed_modules])        module = inspect.getmodule(obj)        if module is None:            return False        mod_name = module.__name__        if any(mod_name.startswith(m) for m in disallowed_modules):            return False        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)    def _find_torch_objects(module):        if any(            module.__name__.startswith(mod_name)            for mod_name in config.allowed_functions_module_string_ignorelist        ):            return        torch_object_ids[id(module)] = module.__name__        for name, obj in list(module.__dict__.items()):            if id(obj) not in torch_object_ids:                if isinstance(obj, types.ModuleType):                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(                        obj                    ):                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"                        _find_torch_objects(obj)                elif _is_allowed_module_prefix(obj):                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"    _find_torch_objects(torch)    _find_torch_objects(math)    # torch.Tensor.{fn}    for name in dir(torch.Tensor):        method = getattr(torch.Tensor, name)        if isinstance(method, types.MethodDescriptorType):            torch_object_ids[id(method)] = f"torch.Tensor.{name}"    for idx in _disallowed_function_ids():        if idx in torch_object_ids:            del torch_object_ids[idx]    for extra in (is_fx_tracing, is_compiling):        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"    return torch_object_ids@make_function_id_setdef _builtin_function_ids():    rv = {        id(v): f"builtins.{k}"        for k, v in builtins.__dict__.items()        if not k.startswith("_") and callable(v)    }    rv.update(        {            id(v): f"operator.{k}"            for k, v in operator.__dict__.items()            if not k.startswith("_") and callable(v)        }    )    rv.update(        {id(v): f"functools.{v.__name__}" for v in (itertools.chain, itertools.islice)}    )    rv[id(functools.reduce)] = "functools.reduce"    return rv@make_function_id_setdef _numpy_function_ids():    rv = dict()    if HAS_NUMPY:        for mod in (np, np.random):            rv.update(                {                    id(v): f"{mod.__name__}.{k}"                    for k, v in mod.__dict__.items()                    if callable(v)                    and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__                }            )    return rv@make_function_id_setdef _builtin_constant_ids():    """    Collects constant builtins by eliminating callable items.    """    rv = {        id(v): f"builtins.{k}"        for k, v in builtins.__dict__.items()        if not k.startswith("_") and not callable(v)    }    return rvdef is_allowed(obj):    """Is this safe to trace like torch.add ?"""    # torch.ops is populated lazily so we don't necessarily have them in    # _allowed_function_ids.  Figure it out by testing the type instead    # in those cases    return id(obj) in _allowed_function_ids or isinstance(        obj,        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),    )def torch_get_name(obj, default):    """Convert a torch.* funcion to a string"""    return _allowed_function_ids.get_name(id(obj), default)def is_builtin_callable(obj):    return id(obj) in _builtin_function_idsdef is_builtin_constant(obj):    return id(obj) in _builtin_constant_idsdef is_numpy(obj):    if HAS_NUMPY:        return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids    else:        return False
 |