123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648 |
- import contextlib
- import ctypes
- import inspect
- import sys
- import types
- from abc import ABC
- from typing import Any, Dict
- import torch._C
- from torch import _utils_internal
- from torch._functorch.pyfunctorch import dispatch_functorch
# Probe for the dlopen-flag accessors a single time at import, not on every call.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager that ORs ``ctypes.RTLD_GLOBAL`` into the interpreter's
    dlopen flags for the duration of the ``with`` body, so that a shared
    library holding custom operators can be opened with that flag set.

    The flags that were active on entry are restored on exit, including when
    the body raises. When ``sys`` lacks ``getdlopenflags``/``setdlopenflags``
    the manager leaves the flags alone and simply runs the body.
    """
    if _SET_GLOBAL_FLAGS:
        saved_flags = sys.getdlopenflags()
        sys.setdlopenflags(saved_flags | ctypes.RTLD_GLOBAL)
        try:
            yield
        finally:
            # Always put the caller's flags back, even on error.
            sys.setdlopenflags(saved_flags)
    else:
        # Nothing to toggle here; just run the body.
        yield
def has_key(op, k):
    """
    Report whether ``op`` has a kernel for dispatch key ``k``.

    A kernel counts if the C++ dispatcher has one registered under
    ``op.name()`` for ``k``, or if ``k`` is present in the operator's Python
    kernel table ``op.py_kernels``.
    """
    registered_in_dispatcher = torch._C._dispatch_has_kernel_for_dispatch_key(
        op.name(), k
    )
    return registered_in_dispatcher or k in op.py_kernels
# TODO(voz) We are missing an entire axis of registration - Modes for the python key
class PyOperatorABC(ABC):
    """
    Common base class for the Python-visible operator objects in this file.

    Every method here is a placeholder that does nothing and returns None;
    the operator classes deriving from this base supply the real versions.
    """

    def __call__(self, *args, **kwargs):
        """Invoke the operator. The base version does nothing."""

    def py_impl(self, dispatch_key, fn):
        """Register a Python implementation. The base version does nothing."""

    def name(self):
        """Return the operator's name. The base version returns None."""
is_included_in_alias = torch._C._dispatch_is_included_in_alias

DispatchKey = torch._C.DispatchKey


# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: PyOperatorABC, k: DispatchKey):  # type: ignore[valid-type]
    """
    Pick the dispatch key whose kernel should serve ``op`` when it is
    dispatched at key ``k``.

    Candidates are tried in a fixed order and the first match wins:
    a direct registration for ``k``, then the alias keys
    CompositeExplicitAutogradNonFunctional, CompositeExplicitAutograd,
    CompositeImplicitAutogradNestedTensor, CompositeImplicitAutograd and
    Autograd, and finally a backend fallback registered for ``k``.

    Raises:
        RuntimeError: ``k`` is AutogradOther, a CompositeImplicitAutograd
            kernel applies, and one of the AutogradOther backends also has
            a kernel.
        NotImplementedError: no candidate applies.
    """
    # 1. (Direct) operator registration
    if has_key(op, k):
        return k

    k_is_undefined = k == DispatchKey.Undefined

    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    # 2.2 Use CompositeExplicitAutograd kernel if available
    for alias_key in (
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeExplicitAutograd,
    ):
        if (k_is_undefined or is_included_in_alias(k, alias_key)) and has_key(
            op, alias_key
        ):
            return alias_key

    # A backend kernel (or an explicit-autograd composite) takes precedence
    # over the implicit-autograd composites handled below.
    has_backend_kernel = torch._C._dispatch_has_kernel_for_any_dispatch_key(
        op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or has_key(op, DispatchKey.CompositeExplicitAutograd)

    # 2.3. Use CompositeImplicitAutograd kernel if available
    nested_key = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        not k_is_undefined
        and is_included_in_alias(k, nested_key)
        and has_key(op, nested_key)
        and not has_backend_kernel
    ):
        return nested_key

    implicit_key = DispatchKey.CompositeImplicitAutograd
    if (k_is_undefined or is_included_in_alias(k, implicit_key)) and has_key(
        op, implicit_key
    ):
        if (
            k == DispatchKey.AutogradOther
            and torch._C._dispatch_has_kernel_for_any_dispatch_key(
                op.name(), torch._C._dispatch_autogradother_backends
            )
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        if not has_backend_kernel:
            return implicit_key

    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    autograd_key = DispatchKey.Autograd
    if is_included_in_alias(k, autograd_key) and has_key(op, autograd_key):
        return autograd_key

    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k

    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
# Registry of every PyOperator created so far, keyed by operator name.
pyop_namespace = {}


class PyOperator(PyOperatorABC):
    """
    An operator implemented purely in Python.

    Implementations are kept in three tables, filled through ``py_impl``:
    ``table`` (keyed by dispatch key), ``python_key_mode_table`` (keyed by
    TorchDispatchMode subclass) and ``functorch_table`` (keyed by functorch
    transform type). Constructing an instance also records it in the
    module-level ``pyop_namespace`` under its name.
    """

    def __init__(self, name):
        self._name = name
        self.table = {}
        self.python_key_mode_table = {}
        self.functorch_table = {}

        # Make _OPNamespace not scream, this whole name based association needs a good hard look
        self.__name__ = name
        pyop_namespace[name] = self

    def fallthrough(self, dispatch_key):
        """Make ``dispatch_key`` defer to the next applicable key after it."""
        self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)

    def py_impl(self, dispatch_key_or_mode_or_transform):
        """
        Return a decorator that registers the decorated function as the
        implementation for a TorchDispatchMode subclass, a functorch
        transform type, or a dispatch key. The decorator returns the
        function unchanged.
        """

        def inner(fn):
            if inspect.isclass(dispatch_key_or_mode_or_transform) and issubclass(
                dispatch_key_or_mode_or_transform,
                torch.utils._python_dispatch.TorchDispatchMode,
            ):
                mode = dispatch_key_or_mode_or_transform
                assert mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[mode] = fn
                return fn

            if isinstance(
                dispatch_key_or_mode_or_transform, torch._C._functorch.TransformType
            ):
                transform = dispatch_key_or_mode_or_transform
                self.functorch_table[transform] = fn
                return fn

            dispatch_key = dispatch_key_or_mode_or_transform
            assert (
                dispatch_key != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
            assert isinstance(dispatch_key, torch._C.DispatchKey)
            assert dispatch_key not in self.table
            self.table[dispatch_key] = fn
            return fn

        return inner

    def dispatch(self, dispatch_key, *args, **kwargs):
        """
        Run the implementation registered for ``dispatch_key`` with the given
        arguments. The functorch front-mode key is routed through
        ``dispatch_functorch``; the Python key is routed to the implementation
        registered for the type of the current dispatch mode.
        """
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode:
            return dispatch_functorch(self, args, kwargs)

        if dispatch_key == torch._C.DispatchKey.Python:
            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
            curr_mode = _get_current_dispatch_mode()
            assert (
                curr_mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            assert (
                type(curr_mode) in self.python_key_mode_table
            ), f"Current active mode {curr_mode} not registered"
            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
            return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)

        assert dispatch_key in self.table, dispatch_key
        return self.table[dispatch_key](*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """
        Call the operator: defer to ``__torch_function__`` handling when any
        flattened argument requests it, otherwise dispatch on the highest
        priority key computed from the arguments.
        """
        flat_args = _to_flat_tuple(args, kwargs)
        if torch.overrides.has_torch_function(flat_args):
            return torch.overrides.handle_torch_function(
                self, flat_args, *args, **kwargs
            )

        dispatch_key_set = _compute_keyset(args, kwargs)
        return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)

    def name(self):
        """Return the operator name given at construction."""
        # The string lives in `_name`; `self.name` is this bound method itself,
        # so returning it would hand callers a method object, not a name.
        return self._name

    # TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
    # as opposed to being this sort of explicit thing where ops are a little too key aware...
    def _fallthrough_fn(self, operator, dispatch_key):
        """
        Build a callable that re-dispatches on the highest priority key that
        comes after ``dispatch_key`` and is present in the keyset of the call
        arguments. ``operator`` is accepted but not used.
        """

        def inner(*args, **kwargs):
            all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key)
            all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
                args, kwargs
            )
            return self.dispatch(
                all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
            )

        return inner
def _to_flat_tuple(args, kwargs):
    """
    Flatten ``args`` and ``kwargs`` with pytree and return the leaves of
    ``args`` followed by the leaves of ``kwargs`` as one sequence.
    """
    tree_flatten = torch.utils._pytree.tree_flatten
    positional_leaves = tree_flatten(args)[0]
    keyword_leaves = tree_flatten(kwargs)[0]
    return positional_leaves + keyword_leaves
def _compute_keyset(args, kwargs):
    """Return the dispatch key set derived from the tensors found in ``args`` and ``kwargs``."""
    return key_extractor(_get_tensors(args, kwargs))
def _get_tensors(args, kwargs):
    """Return a tuple of every ``torch.Tensor`` leaf found in ``args`` and ``kwargs``, in flattened order."""
    return tuple(
        leaf
        for leaf in _to_flat_tuple(args, kwargs)
        if isinstance(leaf, torch.Tensor)
    )
# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
    """
    Compute the dispatch key set for ``tensors``: the thread-local include
    set, unioned with the keys of every tensor, minus the thread-local
    exclude set.
    """
    combined = torch._C._dispatch_tls_local_include_set()
    for t in tensors:
        combined = combined | torch._C._dispatch_keys(t)
    return combined - torch._C._dispatch_tls_local_exclude_set()
# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload(PyOperatorABC):
    def __init__(self, overloadpacket, op, op_dk, schema, tags):
        # `op` is what __call__ invokes; `op_dk` is invoked with an explicit
        # dispatch key as its first argument (see decompose / _get_dispatch).
        self._op = op
        self._op_dk = op_dk
        self._schema = schema
        self._overloadpacket = overloadpacket
        self._tags = tags

        # The unnamed overload is exposed under the name "default".
        self._overloadname = (
            schema.overload_name if schema.overload_name != "" else "default"
        )

        # Qualified name: the schema name, plus ".<overload>" when one is set.
        qualified = self._schema.name
        if schema.overload_name:
            qualified = qualified + "." + schema.overload_name
        self._name = qualified

        # Python kernels registered through py_impl, keyed by dispatch key.
        self.py_kernels: Dict[torch._C.DispatchKey, Any] = {}  # type: ignore[name-defined]

        self.__name__ = "{}.{}".format(
            self._schema.name.split("::")[1], self._overloadname
        )

        # TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base...
        self.python_key_mode_table = {}

        self.__module__ = overloadpacket.__module__
        op.__module__ = overloadpacket.__module__
        self.__qualname__ = self._name
        self.__annotations__ = {}

        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        self._dispatch_cache = {}

        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
        # An op is a view when at least one argument carries alias info and
        # none of the aliased arguments is written to. Mixed mutable /
        # non-mutable aliased inputs are conservatively treated as NOT a view.
        write_flags = [
            arg.alias_info.is_write
            for arg in self._schema.arguments
            if arg.alias_info is not None
        ]
        self.is_view = bool(write_flags) and not any(write_flags)

    # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        schema_name_parts = self._schema.name.split("::")
        return "<OpOverload(op='{}.{}', overload='{}')>".format(
            *schema_name_parts, self._overloadname
        )

    def __call__(self, *args, **kwargs):
        return self._op(*args, **kwargs or {})

    def __hash__(self):
        return hash(self._op)

    # `my_namespace.my_op_name.overload_name`
    def __str__(self):
        schema_name_parts = self._schema.name.split("::")
        return "{}.{}.{}".format(*schema_name_parts, self._overloadname)

    @property
    def namespace(self):
        # The part of the schema name in front of the first "::".
        return self._schema.name.split("::")[0]

    def decompose(self, *args, **kwargs):
        """
        Run this overload's CompositeImplicitAutograd kernel on the given
        arguments, preferring a Python kernel over a dispatcher-registered
        one. Returns ``NotImplemented`` when neither exists.
        """
        dk = torch._C.DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not too necessary anymore, because we can
            # apply Python CompositeImplicitAutograd *before* tracing
            # using Python dispatcher (also taking advantage of the autograd
            # formula). But it's included for completeness
            return self.py_kernels[dk](*args, **kwargs)
        if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
            return self._op_dk(dk, *args, **kwargs)
        return NotImplemented

    def py_impl(self, dispatch_key_or_mode):
        """
        Return a decorator that registers the decorated function as the
        Python implementation for a TorchDispatchMode subclass or for a
        dispatch key. Each registration clears the whole dispatch cache; the
        decorator returns the function unchanged.
        """

        def inner(fn):
            registering_mode = inspect.isclass(dispatch_key_or_mode) and issubclass(
                dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode
            )
            if registering_mode:
                assert dispatch_key_or_mode not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[dispatch_key_or_mode] = fn
            else:
                assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey)
                assert (
                    dispatch_key_or_mode != torch._C.DispatchKey.Python
                ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

                if dispatch_key_or_mode in self.py_kernels:
                    raise RuntimeError(
                        f"Trying to override a python impl for {dispatch_key_or_mode} on operator {self._name}"
                    )
                self.py_kernels[dispatch_key_or_mode] = fn

            # Cached resolutions may no longer be valid after a new registration.
            self._dispatch_cache.clear()
            return fn

        return inner

    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key is not currently cached.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _uncache_dispatch on that key is NOT sufficient to apply your
    # change, because a single registration may affect MULTIPLE dispatch keys
    # (e.g., registering Autograd affects AutogradCPU). _uncache_dispatch is
    # to be used only if you are specifically modifying how _get_dispatch
    # handles a particular input 'key'.
    def _uncache_dispatch(self, key):
        self._dispatch_cache.pop(key, None)

    # This implements the pre-computation logic for the Python dispatcher.
    def _get_dispatch(self, key):
        # This is only called upon a cache miss
        assert key not in self._dispatch_cache, f"{self} {key}"

        if key == torch._C.DispatchKey.Python:
            if self.python_key_mode_table:

                def handler(*args, **kwargs):
                    from torch.utils._python_dispatch import _get_current_dispatch_mode

                    # TODO: We also need to handle tensor subclasses here
                    # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                    curr_mode = type(_get_current_dispatch_mode())
                    # NOTE(review): `curr_mode` is the result of type(...), so it
                    # is never None and this assert cannot fire; with no active
                    # mode the call falls through to `_op_dk` below.
                    assert (
                        curr_mode is not None
                    ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
                    if curr_mode in self.python_key_mode_table:
                        # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
                        return self.python_key_mode_table[curr_mode](*args, **kwargs)
                    # TODO: This path is slow, should generally encourage this
                    # case to not happen
                    return self._op_dk(key, *args, **kwargs)

                python_entry = handler
            else:
                # No mode implementations registered: the key stands for itself.
                python_entry = key

            self._dispatch_cache[key] = python_entry
            return python_entry

        final_key = resolve_key(self, key)

        # TODO: We could potentially have lots of debugging wrappers against
        # dispatch keys; design some general registration mechanism instead of
        # having if statement for each of them
        if key == torch._C.DispatchKey.Functionalize:
            import torch._dispatch.python as pydispatch

            if pydispatch.CROSSREF_FUNCTIONALIZE:
                handler = pydispatch.make_crossref_functionalize(self, final_key)
                self._dispatch_cache[key] = handler
                return handler

        # A Python kernel registered for the resolved key if there is one,
        # otherwise the resolved key itself.
        resolved_entry = self.py_kernels.get(final_key, final_key)
        self._dispatch_cache[key] = resolved_entry
        return resolved_entry

    def name(self):
        return self._name

    @property
    def overloadpacket(self):
        return self._overloadpacket

    @property
    def op(self):
        return self._op

    @property
    def tags(self):
        return self._tags

    # TODO: add more methods to expose information about input and output arguments
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are accessible on the object through the properties
        # defined below but are immutable
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names
        # Names of the overloads resolved so far, in first-access order.
        self._dir = []

    # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        qualified_parts = self._qualified_op_name.split("::")
        return "<OpOverloadPacket(op='{}.{}')>".format(*qualified_parts)

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        qualified_parts = self._qualified_op_name.split("::")
        return "{}.{}".format(*qualified_parts)

    @property
    def op(self):
        return self._op

    def __getattr__(self, key):
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "torch.ops"

        # ensure that query for dunder attributes that does not exist on
        # opoverloadpacket but instead exists on the self._op object does not unnecessarily call
        # `_get_operation_overload` (which is an expensive operation).
        # This is done to prevent any potential slowdown. This list can be extended
        # if there exists other attributes like `__name__` that only exist on self._op and not on the
        # opoverloadpacket.
        # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
        if key.startswith("__"):
            try:
                return getattr(self._op, key)
            except AttributeError:
                # for consistency because it seems weird to
                # throw an attribute error with a message containing
                # an object name different from the one the attribute
                # query was performed on.
                raise AttributeError(
                    "'{}' can't have an overload name beginning with '__' and the "
                    "underlying op {} has no attribute {} either.".format(
                        str(self), str(self._op), key
                    )
                ) from None

        try:
            # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
            lookup_name = "" if key == "default" else key
            # TODO: disallow access to overloads registered by JIT
            op_, op_dk_, tags = torch._C._get_operation_overload(
                self._qualified_op_name, lookup_name
            )
            schema = torch._C._get_schema(self._qualified_op_name, lookup_name)
            resolved = OpOverload(self, op_, op_dk_, schema, tags)
            # cache the overload object
            setattr(self, key, resolved)
            self._dir.append(key)
            return resolved
        except RuntimeError:
            raise AttributeError(
                "The underlying op of '{}' has no overload name '{}'".format(
                    str(self), key
                )
            ) from None

    def __iter__(self):
        return iter(self._dir)

    def __call__(self, *args, **kwargs):
        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
        return self._op(*args, **kwargs or {})

    # TODO: use this to make a __dir__
    def overloads(self):
        # The unnamed overload (empty string) is reported as "default".
        return [overload_name or "default" for overload_name in self._overload_names]
# Resolution of torch.fn is different from torch.ops.aten.fn
# torch.fn uses the Python argparser, matches with the
# appropriate schema, and calls into the unboxed version of the method
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
# JIT creates a stack of all the overloads and then tries to match the
# correct one at runtime and always calls into the boxed version of the method
# Autograd codegen creates VariableType, TracerType,
# inplace or view type and python bindings.
# Aten codegen generates tensor methods for the tensor class.

# _OpNamespace is a subclass of ModuleType because the torch script
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
# to work from script, we need to ensure ops and foo are modules


class _OpNamespace(types.ModuleType):
    """
    An op namespace to dynamically bind Operators into Python.

    Say a user has created a custom Operator called "my_namespace::my_op". To
    call this op, the user will write torch.ops.my_namespace.my_op(...).
    At startup, this operation will not yet be bound into Python. Instead, the
    following sequence of magic tricks will occur:
    1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
       on the `torch.ops` object, which will create a new `_OpNamespace`
       object called `my_namespace` and set it as an attribute on the `ops`
       object.
    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
       the `my_namespace` object, which will retrieve the operation via
       `torch.get_operation`, a function bound from C++, and then in a similar
       fashion bind this new object onto the `my_namespace` object.
    3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
        and subsequent accesses will incur no further lookup (the namespace and
        operation will already exist).
    """

    def __init__(self, name):
        super().__init__("torch.ops." + name)
        self.name = name
        # Op names bound on this namespace so far, in first-access order.
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "torch.ops"
        if op_name == "__origin__":
            raise AttributeError()

        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there are more than one.
        namespace_name = self.name
        qualified_op_name = "{}::{}".format(namespace_name, op_name)
        try:
            op, overload_names = torch._C._jit_get_operation(qualified_op_name)
        except RuntimeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
            ) from e

        # let the script frontend know that op is identical to the builtin op
        # with qualified_op_name
        torch.jit._builtins._register_builtin(op, qualified_op_name)

        # Both the raw op and its packet report the same module path.
        module_path = self.__module__ + "." + namespace_name
        op.__module__ = module_path
        packet = OpOverloadPacket(qualified_op_name, op_name, op, overload_names)
        packet.__module__ = module_path

        # cache the opoverloadpacket to ensure that each op corresponds to
        # a unique OpOverloadPacket object
        setattr(self, op_name, packet)
        self._dir.append(op_name)
        return packet
class _PyOpNamespace(_OpNamespace):
    # Namespace object that exposes the module-level PyOperator registry.
    def __init__(self):
        super().__init__(name="torch.ops")
        # Same dict object as the module-level registry, not a copy.
        self.pyop_namespace = pyop_namespace
class _Ops(types.ModuleType):
    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        # Paths of the shared libraries loaded through load_library.
        self.loaded_libraries = set()
        self.pyops = _PyOpNamespace()
        # Namespace names created so far, in first-access order.
        self._dir = []

    def __getattr__(self, name):
        # Check if the name is a pyop
        if name in self.pyops.pyop_namespace:
            return self.pyops.pyop_namespace[name]

        # Here we are creating `torch.ops.my_namespace`
        new_namespace = _OpNamespace(name)
        # Store it on the instance so later lookups bypass __getattr__.
        setattr(self, name, new_namespace)
        self._dir.append(name)
        return new_namespace

    def __iter__(self):
        return iter(self._dir)

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime. This allows dynamically
        loading custom operators. For this, you should compile your operator
        and the static registration code into a shared library object, and then
        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        # Nothing is loaded when the interpreter reports itself as torch_deploy.
        if sys.executable == "torch_deploy":
            return

        resolved_path = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(resolved_path)
        self.loaded_libraries.add(resolved_path)


# The ops "namespace"
ops = _Ops()
|