- """TorchScript
- This module contains functionality to support the JIT's scripting frontend, notably:
- - torch.jit.script
- This is not intended to be imported directly; please use the exposed
- functionalities in `torch.jit`.
- """
- import functools
- import collections
- import enum
- import inspect
- import copy
- import pickle
- import warnings
- from typing import Any, Dict, List, Set, Tuple, Union, Callable
- import torch
- import torch._jit_internal as _jit_internal
- from torch.utils import set_module
- from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_methods_to_compile, _compile_and_register_class
- from torch.nn import Module
- from torch.jit._state import _enabled
- from torch.jit._builtins import _register_builtin
- from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
- from torch._jit_internal import _qualified_name
- from torch.jit._fuser import _graph_for, _script_method_graph_for
- from torch.jit._state import (
- _try_get_jit_cached_function,
- _try_get_jit_cached_overloads,
- _set_jit_function_cache,
- _set_jit_overload_cache,
- )
- from torch.overrides import (
- has_torch_function, has_torch_function_unary, has_torch_function_variadic)
- from torch.package import PackageExporter, PackageImporter
- from ._serialization import validate_map_location
- from torch.jit._monkeytype_config import (
- monkeytype_trace,
- JitTypeTraceConfig,
- JitTypeTraceStore
- )
- from torch._classes import classes
- type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType
- torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined]
- torch._C.ScriptFunction.graph_for = _graph_for # type: ignore[attr-defined]
- ScriptFunction = torch._C.ScriptFunction
- ScriptFunction.__doc__ = """
- Functionally equivalent to a :class:`ScriptModule`, but represents a single
- function and does not have any attributes or Parameters.
- """
- set_module(ScriptFunction, "torch.jit")
- # Throws an error if a jit function is pickled.
- # Helps to avoid Python crashes on Python 3.9.5+ when pickle protocol 0 or 1 is used.
- def _reduce(cls):
- raise pickle.PickleError("ScriptFunction cannot be pickled")
- ScriptFunction.__reduce__ = _reduce # type: ignore[assignment]
- if _enabled:
- Attribute = collections.namedtuple("Attribute", ["value", "type"])
- else:
- def Attribute(value, type): # type: ignore[no-redef]
- return value
- Attribute.__doc__ = """
- This is a pass-through function that returns `value`, mostly
- used to indicate to the TorchScript compiler that the left-hand side
- expression is a class instance attribute with type `type`. Note that
- `torch.jit.Attribute` should only be used in the `__init__` method of
- `jit.ScriptModule` subclasses.
- Though TorchScript can infer the correct type for most Python expressions, there are some cases where
- type inference can be wrong, including:
- - Empty containers like `[]` and `{}`, which TorchScript assumes to be containers of `Tensor`
- - Optional types like `Optional[T]` that are assigned a valid value of type `T`; TorchScript
- would infer the type as `T` rather than `Optional[T]`
- In eager mode, it is simply a pass-through function that returns `value`
- without other implications.
- Example:
- .. testcode::
- import torch
- from typing import Dict
- class AttributeModule(torch.jit.ScriptModule):
- def __init__(self):
- super().__init__()
- self.foo = torch.jit.Attribute(0.1, float)
- # we should be able to use self.foo as a float here
- assert 0.0 < self.foo
- self.names_ages = torch.jit.Attribute({}, Dict[str, int])
- self.names_ages["someone"] = 20
- assert isinstance(self.names_ages["someone"], int)
- m = AttributeModule()
- # m will contain two attributes
- # 1. foo of type float
- # 2. names_ages of type Dict[str, int]
- .. testcleanup::
- del AttributeModule
- del m
- Note: it is now preferred to use type annotations instead of `torch.jit.Attribute`:
- .. testcode::
- import torch
- from typing import Dict
- class AttributeModule(torch.nn.Module):
- names: Dict[str, int]
- def __init__(self):
- super().__init__()
- self.names = {}
- m = AttributeModule()
- .. testcleanup::
- del AttributeModule
- del m
- Args:
- value: An initial value to be assigned to the attribute.
- type: A Python type
- Returns:
- Returns `value`
- """
- def _get_type_trace_db():
- # This is a private API. Use of this for external purposes is discouraged.
- return type_trace_db
- # Gets a function from the name of a method on a type
- def _get_function_from_type(cls, name):
- return getattr(cls, name, None)
- # ScriptClasses must be new-style classes because we construct them using their
- # __new__ method.
- def _is_new_style_class(cls):
- if hasattr(cls, "__class__"):
- return "__dict__" in dir(cls) or hasattr(cls, "__slots__")
- # These OrderedDictWrapper classes replace the actual OrderedDicts in
- # module with versions that get/set properties inside of Module.
- # This allows us to reuse most of nn.Module while still storing the
- # data in C++.
- # Each OrderedDict needs to support:
- # x not in view
- # x in view
- # view[name] = ...
- # view.values()
- # del view[name]
- # view.items()
- # view.keys()
- # len(view)
- class OrderedDictWrapper:
- def __init__(self, _c):
- self._c = _c
- def keys(self):
- return [k for k, v in self.items()]
- def values(self):
- return [v for k, v in self.items()]
- def __len__(self):
- return len(self.values())
- def __delitem__(self, k):
- raise RuntimeError("cannot delete methods or parameters of a script module")
- def items(self):
- return self._c.items()
- def __setitem__(self, k, v):
- if k not in self:
- raise RuntimeError(
- "Can't add a new parameter after ScriptModule construction."
- " Tried to add '{}".format(k)
- )
- self._c.setattr(k, v)
- def __contains__(self, k):
- return self._c.contains(k)
- def __getitem__(self, k):
- if k not in self:
- raise KeyError(k)
- return self._c.getattr(k)
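- # Illustrative sketch (not part of the original source; assumes a standard
- # PyTorch install): after scripting, `_parameters` is one of these wrappers,
- # so reads and writes round-trip through the C++ module:
- #
- #     sm = torch.jit.script(torch.nn.Linear(2, 2))
- #     "weight" in sm._parameters        # True, via __contains__
- #     list(sm._parameters.keys())       # ['weight', 'bias']
- #     del sm._parameters["weight"]      # RuntimeError: cannot delete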
- class OrderedModuleDict(OrderedDictWrapper):
- def __init__(self, module, python_dict):
- super().__init__(torch._C.ModuleDict(module))
- # contains _both_ script modules and non-script python-only modules
- # because script modules are subclassed in python and the
- # C++ Module class will not hold references to them,
- # to ensure that you always get the same python value here
- # we store it in the python dict as well
- self._python_modules = python_dict
- def items(self):
- r = self._python_modules.items()
- return r
- def __contains__(self, k):
- return k in self._python_modules
- def __setitem__(self, k, v):
- # Cases where sub-module can be re-assigned after ScriptModule construction
- # 1. If the attr is a module interface type, it's guaranteed that the module is
- # not inlined in the graph, so it's safe to swap a new ScriptModule in.
- # 2. If the new value is a ScriptModule with the same JIT type, the IR won't change
- # and it's legit to swap a new module in.
- # In these two cases we allow swapping in a new scripted module and update the
- # corresponding Python module dict to keep them in sync.
- # Note: the value to be swapped in has to be a ScriptModule instead of an nn.Module,
- # otherwise it's illegal and we throw an error.
- if isinstance(v, ScriptModule):
- self._c.setattr(k, v)
- self._python_modules[k] = v
- else:
- raise RuntimeError(
- "Cannot re-assign modules in a ScriptModule with non-scripted "
- "module, tried to replace existing module '{}': {}".format(k, v)
- )
- def __getitem__(self, k):
- return self._python_modules[k]
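- # Illustrative sketch (not part of the original source; `MyModel` is a
- # placeholder module with a `linear` submodule): swapping a submodule on a
- # scripted module only accepts another ScriptModule:
- #
- #     sm = torch.jit.script(MyModel())
- #     sm.linear = torch.jit.script(torch.nn.Linear(2, 2))  # ok, scripted
- #     sm.linear = torch.nn.Linear(2, 2)                    # RuntimeError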
- # For each user-defined class that subclasses ScriptModule, this meta-class:
- # (1) finds all the methods annotated with @script_method in a ScriptModule and
- # removes them from the class attributes
- # (2) puts a wrapper around the class's __init__ method to recursively compile
- # all of the script_methods with the module after the original __init__ has
- # run. This has to occur after the user-defined __init__ so that submodules and
- parameters are initialized _before_ the script compiler resolves references to
- # `self.param` or `self.module`.
- class ScriptMeta(type):
- def __init__(cls, name, bases, attrs): # noqa: B902
- # Aggregate all the ScriptMethods and constants from superclasses
- cls._methods: Dict[str, Any] = {}
- cls._constants_set = set(getattr(cls, "__constants__", ()))
- for base in reversed(bases):
- for k, v in getattr(base, "_methods", {}).items():
- cls._methods[k] = v
- base_constants: Set = getattr(base, "_constants_set", set())
- cls._constants_set = cls._constants_set.union(base_constants)
- # find all the script methods of the current class
- for k, v in sorted(attrs.items()):
- if isinstance(v, ScriptMethodStub):
- delattr(cls, k)
- cls._methods[v.original_method.__name__] = v
- if getattr(cls, "_disable_script_meta", False):
- # We leave built-in ScriptModule types alone, since this metaclass
- # is only for compiling user classes that inherit from
- # ScriptModule.
- return super(ScriptMeta, cls).__init__(name, bases, attrs)
- original_init = getattr(cls, "__init__", lambda self: None)
- @functools.wraps(original_init)
- def init_then_script(self, *args, **kwargs):
- num_methods = len(cls._methods)
- original_init(self, *args, **kwargs)
- added_methods_in_init = len(cls._methods) > num_methods
- if type(self) == cls:
- def make_stubs(module):
- cls = type(module)
- if hasattr(cls, "_methods"):
- return [v for k, v in sorted(cls._methods.items())]
- else:
- return infer_methods_to_compile(module)
- self.__dict__[
- "_actual_script_module"
- ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
- # Delete the Python attributes that now shadow the ScriptModule
- # ones, so that __getattr__ and __setattr__ will properly find
- # the scripted versions.
- concrete_type = self._actual_script_module._concrete_type
- for name in concrete_type.get_attributes():
- delattr(self, name)
- for name, _ in concrete_type.get_modules():
- delattr(self, name)
- for name in ("_parameters", "_buffers", "_modules"):
- delattr(self, name)
- cls.__init__ = init_then_script # type: ignore[misc]
- super(ScriptMeta, cls).__init__(name, bases, attrs)
- class _CachedForward:
- def __get__(self, obj, cls):
- return self.__getattr__("forward") # type: ignore[attr-defined]
- class ScriptWarning(Warning):
- pass
- def script_method(fn):
- if not _enabled:
- return fn
- # NOTE: we need to traverse two frames here because the meta-class frame
- # for ScriptModule will be present, as opposed to invoking @script on a
- # function or invoking define() on a CompilationUnit.
- # The stack will look like:
- #
- # 0. createResolutionCallback()
- # 1. script_method()
- # 2. ScriptModule metaclass frame
- # 3. Surrounding scope
- #
- # createResolutionCallback internally adds 1 to get us to the scope of this
- # function (the calling function). Adding 2 gets us to the proper surrounding scope.
- _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
- ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
- return ScriptMethodStub(_rcb, ast, fn)
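- # Usage sketch (not part of the original source): script_method is the legacy
- # ScriptModule frontend; ScriptMeta collects these stubs and compiles them
- # once the user-defined __init__ has run:
- #
- #     class MyScriptModule(torch.jit.ScriptModule):
- #         @torch.jit.script_method
- #         def forward(self, x):
- #             return x + 1
- #
- #     m = MyScriptModule()
- #     m(torch.ones(2))  # executed by the TorchScript interpreter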
- class ConstMap:
- def __init__(self, const_mapping):
- self.const_mapping = const_mapping
- def __getattr__(self, attr):
- return self.const_mapping[attr]
- def unpackage_script_module(importer: PackageImporter, script_module_id: str) -> torch.nn.Module:
- """
- Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function.
- Performs the work of loading and returning a ScriptModule from a ``torch.package`` archive.
- """
- if not isinstance(importer.zip_reader, torch._C.PyTorchFileReader):
- raise RuntimeError(
- "Loading ScriptObjects from a PackageImporter created from a "
- "directory is not supported. Use a package archive file instead."
- )
- cu = torch._C.CompilationUnit()
- cpp_module = torch._C._import_ir_module_from_package(
- cu,
- importer.zip_reader,
- importer.storage_context,
- validate_map_location(importer.last_map_location),
- script_module_id,
- )
- return wrap_cpp_module(cpp_module)
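- # Round-trip sketch (not part of the original source; `scripted` and the
- # file name are placeholders) of the torch.package flow that ends up here:
- #
- #     with PackageExporter("out.pt") as e:
- #         e.save_pickle("model", "model.pkl", scripted)  # a ScriptModule
- #     imp = PackageImporter("out.pt")
- #     loaded = imp.load_pickle("model", "model.pkl")     # RecursiveScriptModule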
- if _enabled:
- _magic_methods = [
- "__iter__",
- "__len__",
- "__neg__",
- "__mul__",
- "__contains__",
- "__add__",
- "__sub__",
- "__pow__",
- "__truediv__",
- "__mod__",
- "__ne__",
- "__eq__",
- "__lt__",
- "__gt__",
- "__le__",
- "__ge__",
- "__and__",
- "__or__",
- "__xor__",
- "__getitem__",
- "__setitem__",
- "__call__",
- "__int__",
- "__float__",
- "__bool__",
- "__str__",
- "__enter__",
- "__exit__",
- ]
- class RecursiveScriptClass:
- """
- An analogue of RecursiveScriptModule for regular objects that are not modules.
- This class is a wrapper around a torch._C.ScriptObject that represents an instance
- of a TorchScript class and allows it to be used in Python.
- Attributes:
- _c [torch._C.ScriptObject]: The C++ object to which attribute lookups and method
- calls are forwarded.
- _props [Dict[str, property]]: A dictionary of properties fetched from self._c and
- exposed on this wrapper.
- """
- def __init__(self, cpp_class):
- super().__init__()
- self.__dict__["_initializing"] = True
- self._c = cpp_class
- # Add wrapped object's properties to this class instance.
- self._props = {prop.name: property(prop.getter, prop.setter) for prop in self._c._properties()}
- self.__dict__["_initializing"] = False
- def __getattr__(self, attr):
- if "_initializing" in self.__dict__ and self.__dict__["_initializing"]:
- return super().__getattr__(attr) # type: ignore[misc]
- if attr in self._props:
- return self._props[attr].fget() # type: ignore[call-arg, misc]
- return getattr(self._c, attr)
- def __setattr__(self, attr, value):
- if "_initializing" in self.__dict__ and self.__dict__["_initializing"]:
- return super().__setattr__(attr, value)
- if attr in self._props:
- return self._props[attr].fset(value) # type: ignore[call-arg, misc]
- setattr(self._c, attr, value)
- # Delegate calls to magic methods like __len__ to the C++ module backing the
- # RecursiveScriptClass.
- def forward_magic_method(self, method_name, *args, **kwargs):
- if not self._c._has_method(method_name):
- raise TypeError()
- self_method = self.__getattr__(method_name)
- return self_method(*args, **kwargs)
- def __getstate__(self):
- raise pickle.PickleError("ScriptClasses cannot be pickled")
- def __iadd__(self, other):
- if self._c._has_method("__iadd__"):
- return self.forward_magic_method("__iadd__", other)
- else:
- return self.forward_magic_method("__add__", other)
- for method_name in _magic_methods:
- # Bind `method_name` as a default argument so that each generated method
- # captures its own name instead of the loop variable's final value.
- def method_template(self, *args, method_name=method_name, **kwargs):
- return self.forward_magic_method(method_name, *args, **kwargs)
- setattr(RecursiveScriptClass, method_name, method_template)
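- # Behavior sketch (not part of the original source; `Pair` is a placeholder
- # TorchScript class): magic methods defined on a script class are forwarded
- # through forward_magic_method on the Python wrapper:
- #
- #     @torch.jit.script
- #     class Pair:
- #         def __init__(self, a: int, b: int):
- #             self.a = a
- #             self.b = b
- #         def __len__(self):
- #             return 2
- #
- #     # Instances returned from TorchScript to Python come back wrapped as
- #     # RecursiveScriptClass, so len(pair) dispatches to the compiled __len__.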
- # this is a Python 'non-data descriptor' that causes the first access
- # to ScriptModule's forward to look up the forward method and stash
- # it in the objects dict. Due to the standard rules for attribute lookup,
- # subsequent lookups will just directly return the previously looked up method.
- # This is necessary because nn.Module defines forward as a method. If we
- # did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward
- # which always throws an exception.
- class ScriptModule(Module, metaclass=ScriptMeta):
- r"""
- A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s
- contain methods, attributes, parameters, and
- constants. These can be accessed the same way as on a normal ``nn.Module``.
- """
- __jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']
- def __init__(self):
- super().__init__()
- forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
- def __getattr__(self, attr):
- if "_actual_script_module" not in self.__dict__:
- return super().__getattr__(attr)
- return getattr(self._actual_script_module, attr)
- def __setattr__(self, attr, value):
- if "_actual_script_module" not in self.__dict__:
- # Unwrap torch.jit.Attribute into a regular setattr + record
- # the provided type in __annotations__.
- #
- # This ensures that if we use the attr again in `__init__`, it
- # will look like the actual value, not an instance of Attribute.
- if isinstance(value, Attribute):
- # NB: Ensure that we set __annotations__ on the specific
- # class in question, and not on a superclass (which would
- # be wrong wrong wrong!).
- # See also https://github.com/pytorch/pytorch/issues/39463
- if "__annotations__" not in self.__class__.__dict__:
- self.__class__.__annotations__ = {}
- self.__annotations__[attr] = value.type
- value = value.value
- return super().__setattr__(attr, value)
- setattr(self._actual_script_module, attr, value)
- def define(self, src):
- if "_actual_script_module" in self.__dict__:
- # If we have completed initialization, just defer to the
- # backing RecursiveScriptModule to eagerly compile the provided
- # source.
- return self._actual_script_module.define(src)
- # Otherwise, we are still in the object's __init__.
- # In that case, add `src` as a stub to be compiled.
- #
- # We use frames_up=1 to get to the proper surrounding scope. The stack
- # will look like:
- # 0. createResolutionCallback
- # 1. define()
- # 2. surrounding scope.
- #
- # createResolutionCallback internally adds 1 to get us to our frame, then
- # we add 1 to get to the proper surrounding scope.
- rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
- ast = torch._C._parse_source_def(src)
- self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None)
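- # Usage sketch (not part of the original source): define() during __init__
- # stubs the method for compilation along with the rest of the module:
- #
- #     class M(torch.jit.ScriptModule):
- #         def __init__(self):
- #             super().__init__()
- #             self.define("def double(self, x):\n    return 2 * x\n")
- #
- #     M().double(torch.ones(2))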
- def _replicate_for_data_parallel(self):
- return self._actual_script_module._replicate_for_data_parallel()
- def __reduce_package__(self, exporter: PackageExporter):
- """
- Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when
- saving TorchScript objects. Performs the act of saving a ScriptModule inside of
- a ``torch.package`` archive.
- Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s
- Pickler's ``persistent_load`` function.
- """
- script_module_id = exporter.get_unique_id()
- exporter.script_module_serializer.serialize(self._c, int(script_module_id))
- return (unpackage_script_module, (script_module_id,))
- class RecursiveScriptModule(ScriptModule):
- # XXX: RecursiveScriptModule inherits from ScriptModule for the sole
- # reason that it retains the existing isinstance(ScriptModule)
- # behavior.
- r"""
- The core data structure in TorchScript is the ``ScriptModule``. It is an
- analogue of torch's ``nn.Module`` and represents an entire model as a tree of
- submodules. Like normal modules, each individual module in a ``ScriptModule`` can
- have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented
- as Python functions, but in ``ScriptModule``\s methods are implemented as
- TorchScript functions, a statically-typed subset of Python that contains all
- of PyTorch's built-in Tensor operations. This difference allows your
- ``ScriptModule``\s code to run without the need for a Python interpreter.
- ``ScriptModule``\s should not be created manually, instead use
- either :func:`tracing <torch.jit.trace>` or :func:`scripting <torch.jit.script>`.
- Tracing and scripting can be applied incrementally and :ref:`composed as necessary <Types>`.
- * Tracing records the tensor operations as executed with a set of example inputs and uses these
- operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing,
- but values other than Tensors and control flow aren't captured in the graph.
- * Scripting inspects the Python code of the model
- and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow.
- Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary.
- """
- _disable_script_meta = True
- def __init__(self, cpp_module):
- self.__dict__["_initializing"] = True
- self._c = cpp_module
- super().__init__()
- # Delete the 'training' attribute set up by `Module.__init__`. It
- # will get set on the underlying cpp module, so we delete it here
- # to avoid this version shadowing the cpp module version.
- delattr(self, "training")
- @staticmethod
- def _construct(cpp_module, init_fn):
- """
- Construct a RecursiveScriptModule that's ready for use. PyTorch
- code should use this to construct a RecursiveScriptModule instead
- of calling `__init__` directly, as it makes sure the
- object is properly finalized (and in the future, we may take
- control of how the RecursiveScriptModule instance is created).
- Args:
- cpp_module: The C++ Module that will hold the actual state of
- this RecursiveScriptModule instance.
- init_fn: Lambda that initializes the RecursiveScriptModule passed to it.
- """
- script_module = RecursiveScriptModule(cpp_module)
- init_fn(script_module)
- # Finalize the ScriptModule: replace the nn.Module state with our
- # custom implementations and flip the _initializing bit.
- RecursiveScriptModule._finalize_scriptmodule(script_module)
- return script_module
- @staticmethod
- def _finalize_scriptmodule(script_module):
- script_module._parameters = OrderedDictWrapper(
- torch._C.ParameterDict(script_module._c)
- )
- script_module._buffers = OrderedDictWrapper(
- torch._C.BufferDict(script_module._c)
- )
- script_module._modules = OrderedModuleDict(
- script_module._c, script_module._modules
- )
- script_module._initializing = False
- def _reconstruct(self, cpp_module):
- """
- Re-construct an instance of RecursiveScriptModule using an instance of a C++ module.
- Args:
- cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around.
- """
- self.__init__(cpp_module) # type: ignore[misc]
- # Copy the concrete type from the C++ module to this ScriptModule.
- self._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
- self._c._type()
- )
- # Copy submodules from the C++ module to this ScriptModule.
- modules = {}
- for name, cpp_module in torch._C.ModuleDict(self._c).items():
- modules[name] = wrap_cpp_module(cpp_module)
- self._modules = OrderedModuleDict(self._c, modules) # type: ignore[assignment]
- # Copy parameters and buffers.
- self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c)) # type: ignore[assignment]
- self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c)) # type: ignore[assignment]
- # Get rid of the functions from the old C++ module.
- self.__dict__ = {
- k: v
- for k, v in self.__dict__.items()
- if not isinstance(v, torch._C.ScriptMethod)
- }
- self.__dict__["_initializing"] = False
- @property
- def graph(self):
- r"""
- Returns a string representation of the internal graph for the
- ``forward`` method. See :ref:`interpreting-graphs` for details.
- """
- return self._c._get_method("forward").graph
- @property
- def inlined_graph(self):
- r"""
- Returns a string representation of the internal graph for the
- ``forward`` method. This graph will be preprocessed to inline all function and method calls.
- See :ref:`interpreting-graphs` for details.
- """
- return self.forward.inlined_graph # type: ignore[attr-defined]
- @property
- def code(self):
- r"""
- Returns a pretty-printed representation (as valid Python syntax) of
- the internal graph for the ``forward`` method. See
- :ref:`inspecting-code` for details.
- """
- return self.forward.code # type: ignore[attr-defined]
- @property
- def code_with_constants(self):
- r"""
- Returns a tuple of:
- [0] a pretty-printed representation (as valid Python syntax) of
- the internal graph for the ``forward`` method. See `code`.
- [1] a ConstMap following the CONSTANT.cN format of the output in [0].
- The indices in the [0] output are keys to the underlying constant's values.
- See :ref:`inspecting-code` for details.
- """
- r = self.forward.code_with_constants # type: ignore[attr-defined]
- return (r[0], ConstMap(r[1]))
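- # Inspection sketch (not part of the original source):
- #
- #     sm = torch.jit.script(torch.nn.Linear(2, 2))
- #     print(sm.graph)               # IR graph of `forward`
- #     print(sm.code)                # pretty-printed Python-like source
- #     src, consts = sm.code_with_constants
- #     # names like CONSTANTS.c0 in `src` index into `consts`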
- def save(self, f, **kwargs):
- r"""
- save(f, _extra_files={})
- See :func:`torch.jit.save <torch.jit.save>` for details.
- """
- return self._c.save(str(f), **kwargs)
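- # Save/load sketch (not part of the original source; "model.pt" is a
- # placeholder path):
- #
- #     sm.save("model.pt")
- #     loaded = torch.jit.load("model.pt")  # -> RecursiveScriptModule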
- def _save_for_lite_interpreter(self, *args, **kwargs):
- r"""
- _save_for_lite_interpreter(f)
- Add (or update) the bytecode section in the script model. The updated model is used
- in the lite interpreter for mobile applications.
- Args:
- f: a string containing a file name.
- _extra_files: Map from filename to contents which will be stored as part of 'f'.
- """
- return self._c._save_for_mobile(*args, **kwargs)
- def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs):
- return self._c._save_to_buffer_for_mobile(*args, **kwargs)
- def save_to_buffer(self, *args, **kwargs):
- return self._c.save_to_buffer(*args, **kwargs)
- def get_debug_state(self, *args, **kwargs):
- return self._c.get_debug_state()
- def extra_repr(self):
- return "original_name={}".format(self.original_name)
- def graph_for(self, *args, **kwargs):
- return self.forward.graph_for(self, *args, **kwargs) # type: ignore[attr-defined]
- @property
- def original_name(self):
- if type(self).__name__ == str(self._c._type().name()):
- return ""
- return str(self._c._type().name())
- def define(self, src):
- # We use frames_up=1 to get to the proper surrounding scope. The stack
- # will look like:
- # 0. createResolutionCallback
- # 1. define()
- # 2. surrounding scope.
- #
- # createResolutionCallback internally adds 1 to get us to our frame, then
- # we add 1 to get to the proper surrounding scope.
- rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
- self._c._define(self._concrete_type, src, rcb)
- def __getattr__(self, attr):
- if "_initializing" not in self.__dict__:
- raise RuntimeError(
- "ScriptModule has not been initialized, did you forget to call super's init?"
- )
- if self._initializing:
- return super().__getattr__(attr)
- # _modules check is before hasattr since modules are included as attributes in _c,
- # but we want to get the python wrapper from _modules instead of the raw _c object.
- if attr in self._modules:
- return self._modules[attr]
- elif self._c.hasattr(attr):
- return self._c.getattr(attr)
- elif self._c._has_method(attr):
- script_method = self._c._get_method(attr)
- # cache method so future calls do not go through __getattr__
- # to improve invocation performance
- self.__dict__[attr] = script_method
- return script_method
- return super().__getattr__(attr)
- def __setattr__(self, attr, value):
- if self._initializing:
- return super().__setattr__(attr, value)
- if attr in self._modules:
- self._modules[attr] = value
- elif self._c.hasattr(attr):
- self._c.setattr(attr, value)
- elif (
- hasattr(self, "_concrete_type")
- and attr in self._concrete_type.get_constants().keys()
- ):
- # TODO: we don't have _concrete_type set after load(), and in general we lose constant information.
- # We should encode constants as class type attributes (or something) so it persists across save/load.
- raise AttributeError(
- "Cannot mutate TorchScript constant value: '{}'. Value: '{}'".format(
- attr, value
- )
- )
- else:
- # We allow setting Python attributes on the ScriptModule, for
- # when people want to stash some convenience info on it.
- # TODO: it's possible that the following is confusing:
- # s = torch.jit.script(...)
- # s.python_attr = ...
- # s.save() <--- this doesn't have `python_attr`
- # It's fairly trivial to save enough info to warn in this case.
- return super().__setattr__(attr, value)
- def __copy__(self):
- return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c))
- def __deepcopy__(self, memo):
- return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo))
- # Python magic methods do method lookups on an object's class type, instead of looking up
- # the methods defined on the class instance. In order to continue to expose the magic methods
- # of built-in containers (ModuleList, Sequential, ModuleDict) to Python, we
- # define magic methods here as a shim to the correct attribute.
- def forward_magic_method(self, method_name, *args, **kwargs):
- self_method = getattr(self, method_name)
- if getattr(self_method, "__func__", None) == getattr(
- RecursiveScriptModule, method_name
- ):
- raise NotImplementedError()
- return self_method(*args, **kwargs)
- def __iter__(self):
- return self.forward_magic_method("__iter__")
- def __getitem__(self, idx):
- return self.forward_magic_method("__getitem__", idx)
- def __len__(self):
- return self.forward_magic_method("__len__")
- def __contains__(self, key):
- return self.forward_magic_method("__contains__", key)
- # dir is defined by the base nn.Module, so instead of throwing if
- # it is not overridden, we call into the nn.Module __dir__ method
- def __dir__(self):
- self_method = self.__dir__
- if self_method.__func__ == _get_function_from_type( # type: ignore[attr-defined]
- RecursiveScriptModule, "__dir__"
- ):
- return super().__dir__()
- return self_method()
- # To resolve bool(value), Python checks whether __bool__ is defined, then
- # __iter__, and otherwise returns True for objects. Since __iter__() on this
- # class throws if it isn't overridden, we define __bool__ to preserve the default behavior.
- def __bool__(self):
- self_method = self.__bool__
- if self_method.__func__ == _get_function_from_type( # type: ignore[attr-defined]
- RecursiveScriptModule, "__bool__"
- ):
- return True
- return self_method()
- def _replicate_for_data_parallel(self):
- # we have to initialize ScriptModule properly so that
- # it works with pybind11
- def init_fn(script_module):
- # Don't do anything here, we'll initialize the ScriptModule below
- return
- return RecursiveScriptModule._construct(
- self._c._replicate_for_data_parallel(), init_fn
- )
- # Need to copy all RecursiveScriptModule methods to ScriptModule.
- #
- # This is because `super().foo()` does not use
- # `__getattr__` to look up `foo`. So we need to make each method available on
- # the ScriptModule manually.
- for name, item in RecursiveScriptModule.__dict__.items():
- if not callable(item) and not isinstance(item, property):
- continue
- if name.startswith("__") or hasattr(ScriptModule, name):
- continue
- # We can copy over the implementation wholesale because besides the
- # `super()` thing above, ScriptModule behaves exactly like
- # RecursiveScriptModule
- setattr(ScriptModule, name, item)
- def _get_methods(cls):
- # Unbound methods are plain functions in Python 3, so match both
- # functions and methods.
- return inspect.getmembers(
- cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x)
- )
- _compiled_methods_allowlist = {
- "forward",
- "register_buffer",
- "register_parameter",
- "register_module",
- "add_module",
- "_apply",
- "apply",
- "cuda",
- "cpu",
- "to",
- "type",
- "float",
- "double",
- "half",
- "state_dict",
- "_save_to_state_dict",
- "load_state_dict",
- "_load_from_state_dict",
- "_named_members",
- "parameters",
- "named_parameters",
- "buffers",
- "named_buffers",
- "children",
- "named_children",
- "modules",
- "named_modules",
- "zero_grad",
- "share_memory",
- "_get_name",
- "extra_repr",
- "_slow_forward",
- "_tracing_name",
- "eval",
- "train",
- "get_extra_state",
- "set_extra_state"
- }
- def _make_fail(name):
- def fail(self, *args, **kwargs):
- raise RuntimeError(name + " is not supported on ScriptModules")
- return fail
- for name, method in _get_methods(torch.nn.Module):
- if name.startswith("__"):
- continue
- if (
- name not in RecursiveScriptModule.__dict__
- and name not in _compiled_methods_allowlist
- ):
- setattr(RecursiveScriptModule, method.__name__, _make_fail(name))
- else:
- # TODO MAKE SURE THAT DISABLING WORKS
- class RecursiveScriptClass: # type: ignore[no-redef]
- pass
- class ScriptModule(torch.nn.Module): # type: ignore[no-redef]
- def __init__(self, arg=None):
- super().__init__()
- class RecursiveScriptModule(ScriptModule): # type: ignore[no-redef]
- def __init__(self, arg=None):
- super().__init__()
- def call_prepare_scriptable_func_impl(obj, memo):
- if not isinstance(obj, torch.nn.Module):
- return obj
- obj_id = id(obj)
- # If obj_id is in memo, obj has already been prepared or is being
- # prepared in another call up the stack.
- if obj_id in memo:
- return memo[id(obj)]
- obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj # type: ignore[operator]
- # Record obj in memo to avoid infinite recursion in the case of cycles in the module
- # hierarchy when recursing below.
- memo[obj_id] = obj
- new_obj_dict = {}
- for name, sub_module in obj.__dict__.items():
- if name == '_modules':
- for k, v in sub_module.items():
- sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
- new_obj_dict[name] = sub_module
- elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
- new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
- else:
- new_obj_dict[name] = sub_module
- for k, v in new_obj_dict.items():
- obj.__dict__[k] = v
- return obj
- def call_prepare_scriptable_func(obj):
- memo: Dict[int, torch.nn.Module] = {}
- return call_prepare_scriptable_func_impl(obj, memo)
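- # Hook sketch (not part of the original source; `ScriptableOuter` is a
- # placeholder): a module may return a scripting-friendly stand-in from
- # __prepare_scriptable__, which script() applies recursively before compiling:
- #
- #     class Outer(torch.nn.Module):
- #         def __prepare_scriptable__(self):
- #             return ScriptableOuter()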
- def create_script_dict(obj):
- """
- Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.
- Args:
- obj (dict): The Python dictionary that is used to initialize the ``ScriptDict``
- returned by this function.
- Returns:
- An instance of ``torch._C.ScriptDict`` that has the same data as ``obj``
- and can be passed between Python and TorchScript with reference semantics and
- zero copy overhead.
- """
- return torch._C.ScriptDict(obj) # type: ignore[attr-defined]
- def create_script_list(obj, type_hint=None):
- """
- Create a ``torch._C.ScriptList`` instance with the data from ``obj``.
- Args:
- obj (list): The Python list that is used to initialize the ``ScriptList``
- returned by this function.
- Returns:
- An instance of ``torch._C.ScriptList`` that has the same data as ``obj``
- and can be passed between Python and TorchScript with reference semantics and
- zero copy overhead.
- """
- return torch._C.ScriptList(obj) # type: ignore[attr-defined]
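- # Reference-semantics sketch (not part of the original source):
- #
- #     d = torch.jit.script({"k": torch.ones(2)})  # torch._C.ScriptDict
- #     l = torch.jit.script([1, 2, 3])             # torch._C.ScriptList
- #     # Passing `d` or `l` to a scripted function shares the underlying
- #     # container, so mutations made in TorchScript are visible in Python.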
- def script(obj, optimize=None, _frames_up=0, _rcb=None,
- example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
- r"""
- Scripting a function or ``nn.Module`` will inspect the source code, compile
- it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
- :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
- features in Python work, but we provide enough functionality to compute on
- tensors and do control-dependent operations. For a complete guide, see the
- :ref:`language-reference`.
- Scripting a dictionary or list copies the data inside it into a TorchScript instance that can be
- subsequently passed by reference between Python and TorchScript with zero copy overhead.
- ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists
- and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.
- Args:
- obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type,
- dictionary, or list to compile.
- example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs
- to annotate the arguments for a function or ``nn.Module``.
- Returns:
- If ``obj`` is ``nn.Module``, ``script`` returns
- a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
- have the same set of sub-modules and parameters as the
- original ``nn.Module``. If ``obj`` is a standalone function,
- a :class:`ScriptFunction` will be returned. If ``obj`` is a ``dict``, then
- ``script`` returns an instance of `torch._C.ScriptDict`. If ``obj`` is a ``list``,
- then ``script`` returns an instance of `torch._C.ScriptList`.
- **Scripting a function**
- The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
- by compiling the body of the function.
- Example (scripting a function):
- .. testcode::
- import torch
- @torch.jit.script
- def foo(x, y):
- if x.max() > y.max():
- r = x
- else:
- r = y
- return r
- print(type(foo)) # torch.jit.ScriptFunction
- # See the compiled graph as Python code
- print(foo.code)
- # Call the function using the TorchScript interpreter
- foo(torch.ones(2, 2), torch.ones(2, 2))
- .. testoutput::
- :hide:
- ...
- **Scripting a function using example_inputs**
- Example inputs can be used to annotate a function's arguments.
- Example (annotating a function before scripting):
- .. testcode::
- import torch
- def test_sum(a, b):
- return a + b
- # Annotate the arguments to be int
- scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])
- print(type(scripted_fn)) # torch.jit.ScriptFunction
- # See the compiled graph as Python code
- print(scripted_fn.code)
- # Call the function using the TorchScript interpreter
- scripted_fn(20, 100)
- .. testoutput::
- :hide:
- ...
- **Scripting an nn.Module**
- Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
- compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
- features supported in TorchScript, no changes to the original module code should be necessary. ``script``
- will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
- the original module.
- Example (scripting a simple module with a Parameter):
- .. testcode::
- import torch
- class MyModule(torch.nn.Module):
- def __init__(self, N, M):
- super().__init__()
- # This parameter will be copied to the new ScriptModule
- self.weight = torch.nn.Parameter(torch.rand(N, M))
- # When this submodule is used, it will be compiled
- self.linear = torch.nn.Linear(N, M)
- def forward(self, input):
- output = self.weight.mv(input)
- # This calls the `forward` method of the `nn.Linear` module, which will
- # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
- output = self.linear(output)
- return output
- scripted_module = torch.jit.script(MyModule(2, 3))
- Example (scripting a module with traced submodules):
- .. testcode::
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- class MyModule(nn.Module):
- def __init__(self):
- super().__init__()
- # torch.jit.trace produces ScriptModules for conv1 and conv2
- self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
- self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
- def forward(self, input):
- input = F.relu(self.conv1(input))
- input = F.relu(self.conv2(input))
- return input
- scripted_module = torch.jit.script(MyModule())
- To compile a method other than ``forward`` (and recursively compile anything it calls), add
- the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
- use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.
- Example (an exported and ignored method in a module)::
- import torch
- import torch.nn as nn
- class MyModule(nn.Module):
- def __init__(self):
- super().__init__()
- @torch.jit.export
- def some_entry_point(self, input):
- return input + 10
- @torch.jit.ignore
- def python_only_fn(self, input):
- # This function won't be compiled, so any
- # Python APIs can be used
- import pdb
- pdb.set_trace()
- def forward(self, input):
- if self.training:
- self.python_only_fn(input)
- return input * 99
- scripted_module = torch.jit.script(MyModule())
- print(scripted_module.some_entry_point(torch.randn(2, 2)))
- print(scripted_module(torch.randn(2, 2)))
- Example (annotating the forward of an nn.Module using example_inputs)::
- import torch
- import torch.nn as nn
- from typing import List, NamedTuple
- class MyModule(NamedTuple):
- result: List[int]
- class TestNNModule(torch.nn.Module):
- def forward(self, a) -> MyModule:
- result = MyModule(result=a)
- return result
- pdt_model = TestNNModule()
- # Runs pdt_model in eager mode with the provided inputs and annotates the arguments of forward
- scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
- # Run the scripted_model with actual inputs
- print(scripted_model([20]))
- """
- global type_trace_db
- if not _enabled:
- return obj
- if optimize is not None:
- warnings.warn(
- "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
- )
- # No-op for modules, functions, class instances that are already scripted
- if isinstance(obj, RecursiveScriptClass):
- return obj
- if isinstance(obj, ScriptModule):
- return obj
- if isinstance(obj, ScriptFunction):
- return obj
- if example_inputs:
- # If MonkeyType is installed, enable profile-directed type annotation.
- # Check if example_inputs are defined and generate call traces
- # for the method by running the eager-mode version of the method with
- # the provided example inputs. This logs all the traces in type_trace_db.
- type_trace_db = JitTypeTraceStore()
- if monkeytype_trace:
- monkeytype_config = JitTypeTraceConfig(type_trace_db)
- with monkeytype_trace(monkeytype_config):
- if isinstance(example_inputs, dict):
- # If the obj is an nn.Module or a class, then each method is
- # executed with the arguments provided in the example inputs.
- # example_inputs here is of the form Dict[class.method, List[Tuple[arguments]]].
- # This is used to infer type annotations for those methods
- # which are not called directly under the hood of monkeytype.
- for module, example_input in example_inputs.items():
- for example in example_input:
- module(*example)
- elif isinstance(example_inputs, list):
- for examples in example_inputs:
- obj(*examples)
- else:
- raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
- " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
- else:
- warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
- "to enable Profile-Directed Typing in TorchScript. Refer to "
- "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
- if isinstance(obj, torch.nn.Module):
- obj = call_prepare_scriptable_func(obj)
- return torch.jit._recursive.create_script_module(
- obj, torch.jit._recursive.infer_methods_to_compile
- )
- if isinstance(obj, dict):
- return create_script_dict(obj)
- if isinstance(obj, list):
- return create_script_list(obj)
- if inspect.isclass(obj):
- qualified_name = _qualified_name(obj)
- # If this type is a `nn.Module` subclass, the caller probably meant to pass
- # an instance instead of the class itself
- if issubclass(obj, torch.nn.Module):
- raise RuntimeError(
- "Type '{}' cannot be compiled since it inherits"
- " from nn.Module,"
- " pass an instance instead".format(obj)
- )
- # Enums are automatically usable in TorchScript, explicitly scripting
- # is not necessary, but not harmful either.
- if issubclass(obj, enum.Enum):
- return obj
- if not _is_new_style_class(obj):
- raise RuntimeError(
- "TorchScript classes must be new-style classes. "
- "Please inherit from 'object'."
- )
- if len(obj.mro()) > 2:
- raise RuntimeError(
- "TorchScript classes does not support inheritance yet. "
- "Please directly inherit from 'object'."
- )
- if _rcb is None:
- _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
- _compile_and_register_class(obj, _rcb, qualified_name)
- return obj
- elif inspect.isfunction(obj) or inspect.ismethod(obj):
- qualified_name = _qualified_name(obj)
- # this is a decorated fn, and we need to get the underlying fn and its rcb
- if hasattr(obj, "__script_if_tracing_wrapper"):
- obj = obj.__original_fn # type: ignore[union-attr]
- _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
- # some functions are explicitly marked as not supported in script mode
- if hasattr(obj, "__script_unsupported"):
- raise RuntimeError("TorchScript error: " + obj.__script_unsupported)
- _check_directly_compile_overloaded(obj)
- maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
- if maybe_already_compiled_fn:
- return maybe_already_compiled_fn
- ast = get_jit_def(obj, obj.__name__)
- if _rcb is None:
- _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
- fn = torch._C._jit_script_compile(
- qualified_name, ast, _rcb, get_default_args(obj)
- )
- # Forward docstrings
- fn.__doc__ = obj.__doc__
- # Allow torch.compile() to inline
- fn._torchdynamo_inline = obj # type: ignore[attr-defined]
- _set_jit_function_cache(obj, fn)
- return fn
- else:
- return torch.jit._recursive.create_script_class(obj)
- # overloads are registered in _jit_internal and compiled here so that _overload
- # can be used in nn/functional.py without an import cycle
- def _check_overload_defaults(impl_defaults, overload_defaults, loc):
- for name, overload_value in overload_defaults.items():
- if name not in impl_defaults or impl_defaults[name] != overload_value:
- raise torch.jit.frontend.FrontendError(
- loc,
- "Default parameters on overloads do not affect the runtime so they "
- "must equal to the default parameter on the implementation function. Found on "
- "parameter {name}".format(name=name),
- )
- def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
- overload_decl = get_jit_def(overload_fn, overload_fn.__name__).decl()
- overload_signature = torch.jit.annotations.get_signature(
- overload_fn, None, None, inspect.ismethod(overload_fn)
- )
- impl_ast = get_jit_def(impl_fn, impl_fn.__name__)
- overload_defaults = get_default_args(overload_fn)
- implementation_defaults = get_default_args(impl_fn)
- _rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)
- _check_overload_defaults(
- implementation_defaults, overload_defaults, overload_decl.range()
- )
- fn = torch._C._jit_script_compile_overload(
- qual_name,
- overload_decl,
- impl_ast,
- _rcb,
- implementation_defaults,
- overload_signature,
- )
- return fn
- def _get_overloads(obj):
- # check for cached compiled fns
- existing_compiled_fns = _try_get_jit_cached_overloads(obj)
- qual_name = _qualified_name(obj)
- uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name)
- if uncompiled_overloads is None:
- return existing_compiled_fns
- if obj in uncompiled_overloads:
- raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
- 'function', obj))
- compiled_fns = []
- for overload_fn in uncompiled_overloads:
- compiled_fns.append(
- _compile_function_with_overload(overload_fn, qual_name, obj)
- )
- if existing_compiled_fns:
- compiled_fns = existing_compiled_fns + compiled_fns
- # Cache the compiled overloads and clear the info that was stored for compilation
- _set_jit_overload_cache(obj, compiled_fns)
- _jit_internal._clear_fn_overloads(qual_name)
- return compiled_fns
- def _check_directly_compile_overloaded(obj):
- qual_name = _qualified_name(obj)
- if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj):
- raise RuntimeError(
- "Function {} cannot be directly compiled because it"
- " is overloaded. It must be used in a context of a function"
- " where its inputs can determine which overload to call.".format(qual_name)
- )
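- # Overload sketch (not part of the original source): declarations registered
- # with torch.jit._overload are compiled here when the implementation is
- # referenced from scripted code:
- #
- #     from torch.jit import _overload
- #
- #     @_overload
- #     def double(x: int) -> int: ...
- #     @_overload
- #     def double(x: str) -> str: ...
- #     def double(x):
- #         return x + x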
- def interface(obj):
- if not inspect.isclass(obj):
- raise RuntimeError("interface must be applied to a class")
- if not _is_new_style_class(obj):
- raise RuntimeError("TorchScript interfaces must inherit from 'object'")
- # Expected MRO is:
- # User module
- # torch.nn.modules.module.Module
- # object
- is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3
- if not is_module_interface and len(obj.mro()) > 2:
- raise RuntimeError(
- "TorchScript interface does not support inheritance yet. "
- "Please directly inherit from 'object' or 'nn.Module'."
- )
- qualified_name = _qualified_name(obj)
- rcb = _jit_internal.createResolutionCallbackFromFrame(1)
- # if this type is a `nn.Module` subclass, generate a module interface type
- # instead of a class interface type; a module interface type only compiles
- # the user provided methods as part of the interface
- ast = get_jit_class_def(obj, obj.__name__)
- mangled_classname = torch._C._jit_script_interface_compile(
- qualified_name, ast, rcb, is_module_interface
- )
- obj.__torch_script_interface__ = mangled_classname
- return obj
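- # Usage sketch (not part of the original source): a module interface lets a
- # submodule slot accept any ScriptModule implementing the declared methods:
- #
- #     @torch.jit.interface
- #     class OneToOne(torch.nn.Module):
- #         def forward(self, x: torch.Tensor) -> torch.Tensor:
- #             pass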
- def _recursive_compile_class(obj, loc):
- _qual_name = _qualified_name(obj)
- # We're starting a new compilation, so update the error call stack in
- # case it fails
- error_stack = torch._C.CallStack(_qual_name, loc)
- rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
- return _compile_and_register_class(obj, rcb, _qual_name)
- CompilationUnit = torch._C.CompilationUnit
- set_module(CompilationUnit, "torch.jit")
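- # Usage sketch (not part of the original source): a CompilationUnit compiles
- # free-standing TorchScript source:
- #
- #     cu = torch.jit.CompilationUnit('''
- #     def add(a: int, b: int) -> int:
- #         return a + b
- #     ''')
- #     cu.add(1, 2)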
- def pad(s: str, padding: int, offset: int = 0, char: str = ' '):
- if padding >= len(s):
- padding -= len(s)
- return char * (padding + offset) + s
- class _ScriptProfileColumn:
- def __init__(self, header: str, alignment: int = 4, offset: int = 0):
- self.header = header
- self.alignment = alignment
- self.offset = offset
- self.rows: Dict[int, Any] = {}
- def add_row(self, lineno: int, value: Any):
- self.rows[lineno] = value
- def materialize(self):
- max_length = len(self.header)
- rows: List[Tuple[int, str]] = []
- for (key, value) in self.rows.items():
- cell = str(value)
- rows.append((key, cell))
- max_length = max(len(cell), max_length)
- if self.alignment > 0:
- padding = max_length + self.alignment
- padding -= padding % self.alignment
- else:
- padding = 0
- rows = [(key, pad(cell, padding, self.offset)) for key, cell in rows]
- return pad(self.header, padding, self.offset), rows
- class _ScriptProfileTable:
- def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]):
- self.cols = cols
- self.source_range = source_range
- def dump_string(self):
- outputs: List[str] = []
- cells: List[Tuple[str, Dict[int, str]]] = []
- header_buffer = ''
- for col in self.cols:
- header, rows = col.materialize()
- header_buffer += header
- cells.append((header, dict(rows)))
- outputs.append(header_buffer)
- outputs.append(pad('', len(header_buffer), 0, '='))
- for line in self.source_range:
- row_buffer = ''
- for header, rows in cells:
- cell = rows.get(line)
- if cell is None:
- row_buffer += pad('', len(header))
- else:
- row_buffer += cell
- outputs.append(row_buffer)
- return '\n'.join(outputs)
- class _ScriptProfile:
- def __init__(self):
- self.profile = classes.profiling._ScriptProfile()
- def enable(self):
- self.profile.enable()
- def disable(self):
- self.profile.disable()
- def dump_string(self) -> str:
- outputs: List[str] = []
- for source_stats in self.profile._dump_stats():
- source_ref = source_stats.source()
- source_lines = source_ref.text().splitlines()
- dedent = min([len(line) - len(line.lstrip(' ')) for line in source_lines])
- source_lines = [line[dedent:] for line in source_lines]
- start_line = source_ref.starting_lineno()
- end_line = start_line + len(source_lines)
- source_range = range(start_line, end_line)
- lineno = _ScriptProfileColumn("Line #")
- hits = _ScriptProfileColumn("Hits")
- time_ns = _ScriptProfileColumn("Time (ns)")
- line_contents = _ScriptProfileColumn("Line Contents", 0, 1)
- stats = source_stats.line_map()
- for line in source_range:
- lineno.add_row(line, line)
- line_contents.add_row(line, source_lines[line - start_line])
- stat = stats.get(line)
- if stat is not None:
- hits.add_row(line, stat.count())
- time_ns.add_row(line, stat.duration_ns())
- table = _ScriptProfileTable([lineno, hits, time_ns, line_contents], list(source_range))
- outputs.append(table.dump_string())
- return '\n\n'.join(outputs)
- def dump(self):
- print(self.dump_string())
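- # Usage sketch (not part of the original source; _ScriptProfile is a private
- # API and may change): per-line instrumentation of scripted code:
- #
- #     prof = _ScriptProfile()
- #     prof.enable()
- #     scripted_fn(torch.ones(2))   # any ScriptFunction / ScriptModule call
- #     prof.disable()
- #     print(prof.dump_string())    # Line # / Hits / Time (ns) per source line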
- def _unwrap_optional(x):
- assert x is not None, "Unwrapping null optional"
- return x
- _register_builtin(_unwrap_optional, "aten::_unwrap_optional")
- _register_builtin(_jit_internal.is_scripting, "aten::is_scripting")
- _register_builtin(has_torch_function, "aten::has_torch_function")
- _register_builtin(has_torch_function_unary, "aten::has_torch_function")
- _register_builtin(has_torch_function_variadic, "aten::has_torch_function")
|