import inspect
import torch
import types
import collections
import textwrap
import functools
import warnings
import sys
from typing import Dict, List, Set, Type

import torch._jit_internal as _jit_internal
from torch._sources import fake_range
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def, get_class_properties
from torch.jit._builtins import _find_builtin
from torch.jit._check import AttributeTypeIsSupportedChecker
from torch.jit._state import _python_cu, _add_script_class, _get_script_class
from torch.nn import Module

ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
PropertyStub = collections.namedtuple('PropertyStub', ('resolution_callback', 'def_'))

# TODO: there should be a more principled way of doing this.
ignored_attributes = [
    "_version",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    "_backward_hooks",
    "_backward_pre_hooks",
    "_forward_hooks",
    "_forward_hooks_with_kwargs",
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_state_dict_hooks",
    "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    "_modules",
    "_initializing",
    "dump_patches",
]


def _compile_and_register_class(obj, rcb, qualified_name):
    script_class = _get_script_class(obj)

    if not script_class:
        ast = get_jit_class_def(obj, obj.__name__)
        defaults = torch.jit.frontend.get_default_args_for_class(obj)
        script_class = torch._C._jit_script_class_compile(qualified_name, ast, defaults, rcb)
        _add_script_class(obj, script_class)

    return script_class


def make_stub(func, name):
    rcb = _jit_internal.createResolutionCallbackFromClosure(func)
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
    return ScriptMethodStub(rcb, ast, func)


def make_stub_from_method(nn_module, method_name):
    func = getattr(nn_module, method_name)
    if isinstance(func, ScriptMethodStub):
        return func
    # Make sure the name present in the resulting AST will match the name
    # requested here. The only time they don't match is if you do something
    # like:
    #   def _forward(self):
    #       pass
    #   forward = _forward
    # In this case, the actual function object will have the name `_forward`,
    # even though we requested a stub for `forward`.
    return make_stub(func, method_name)


def make_stubs_from_exported_methods(mod):
    stubs = []
    for name in dir(mod):
        item = getattr(mod, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.EXPORT
        ):
            stubs.append(make_stub_from_method(mod, name))

    return stubs


def jit_ignored_properties(module):
    user_annotated_ignored_attributes = getattr(module, "__jit_ignored_attributes__", list())

    def get_properties_names(module):
        return {k for k, v in vars(module).items() if isinstance(v, property)}

    properties = get_properties_names(type(module))
    user_annotated_ignored_properties = set()

    for ignored_attr in user_annotated_ignored_attributes:
        if ignored_attr in properties:
            user_annotated_ignored_properties.add(ignored_attr)
    return user_annotated_ignored_properties
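
# A minimal usage sketch (illustrative only; `MyModule` and its members are
# hypothetical, not part of this file). Properties listed in
# `__jit_ignored_attributes__` are skipped during scripting:
#
#   class MyModule(torch.nn.Module):
#       __jit_ignored_attributes__ = ["dirty_state"]
#
#       @property
#       def dirty_state(self):  # relies on Python-only behavior
#           return self.__dict__.get("cache")
#
#   jit_ignored_properties(MyModule())  # -> {"dirty_state"}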


# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
# If you edit this list, then you also need to edit the handlers in
# ConstantValue in jit/script/init.cpp
_constant_types = (bool, float, int, str, type(None), torch.device, torch.layout, torch.dtype)


def _get_valid_constant(attr, v, owner_type):
    if isinstance(v, _constant_types):
        return v
    elif isinstance(v, (tuple, list)):
        return tuple(_get_valid_constant(attr, x, owner_type) for x in v)
    constants = ", ".join(torch.typename(typ) for typ in _constant_types)
    raise TypeError(textwrap.dedent("""
        '{}' object in attribute '{}.{}' is not a valid constant.
        Valid constants are:
        1. a nn.ModuleList
        2. a value of type {{{}}}
        3. a list or tuple of (2)
        """.format(torch.typename(type(v)), owner_type, attr, constants)))


class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
    def __init__(self, source, filename, file_lineno, leading_whitespace_len):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)


def infer_concrete_type_builder(nn_module, share_types=True):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module. The resulting
    ConcreteModuleType doesn't have a JIT type associated with it yet; it
    must be filled in by the caller.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, (torch.nn.ModuleDict)):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()
    if isinstance(nn_module, (torch.nn.ParameterList)):
        concrete_type_builder.set_parameter_list()
    if isinstance(nn_module, (torch.nn.ParameterDict)):
        concrete_type_builder.set_parameter_dict()

    def get_annotations(obj):
        if sys.version_info < (3, 10):
            return getattr(obj, '__annotations__', {})
        # In Python 3.10+ it is recommended to use inspect.get_annotations;
        # see https://docs.python.org/3.10/howto/annotations.html
        # But in 3.10, annotations from a base class are not inherited by an
        # unannotated derived class, so they must be manually extracted.
        annotations = inspect.get_annotations(obj)
        if len(annotations) > 0:
            return annotations

        cls = obj if isinstance(obj, type) else type(obj)
        if len(cls.__bases__) == 0:
            return {}
        return inspect.get_annotations(cls.__bases__[0])
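
    # For illustration (not executed here; `Base`/`Derived` are hypothetical):
    # under Python 3.10+, a subclass that declares no annotations of its own
    # would otherwise report an empty dict, so the fallback above walks to the
    # first base class:
    #
    #   class Base(torch.nn.Module):
    #       x: int
    #
    #   class Derived(Base):  # no annotations of its own
    #       pass
    #
    #   inspect.get_annotations(Derived)  # {} on 3.10+
    #   get_annotations(Derived)          # {'x': int} via the fallback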

    class_annotations = get_annotations(nn_module)
    if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)):
        class_annotations = {}

    # Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list())
    concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)
    ignored_properties = jit_ignored_properties(nn_module)

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        # The forward function from Module is special; never use its
        # annotations; we need to infer the type directly using the JIT. I
        # originally wanted to write this test as
        # isinstance(class_annotations[name], Callable), but isinstance on
        # typing things doesn't seem to work: isinstance(list, Callable) is
        # also true!
        inferred = False
        try:
            if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
                ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as re:
            raise RuntimeError(
                "Error inferring type for {name}: {item}: {re}".format(name=name, item=item, re=re)
            ) from re

        return attr_type, inferred

    added_names = set()

    for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type.type(), False, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        if name in user_annotated_ignored_attributes:
            continue

        attr_type, _ = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so we register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
            continue
        if attr_type.success():
            assert attr_type.type().is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type.type())
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = get_module_concrete_type(item, share_types)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError("added_names must be submodule, parameter, or buffer")

            warnings.warn("'{}' was found in ScriptModule constants, "
                          "but it is a non-constant {}. Consider removing it.".format(name, hint))
            continue

        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            warnings.warn("'{}' was found in ScriptModule constants, "
                          "but was not actually set in __init__. "
                          "Consider removing it.".format(name))
            continue

        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(name, _get_valid_constant(name, value, type(nn_module).__name__))
        added_names.add(name)

    # populate overloads
    overloads = getattr(nn_module, "__overloads__", {})
    # update with any annotated overloads
    overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module, ignored_properties)))
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in ignored_attributes or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in user_annotated_ignored_attributes:
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
        if isoverloadpacket:
            value = value.op

        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name,
                    torch._C._jit_try_infer_type(scripted_fn).type(),
                    value)
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = ("(This function exists as an attribute on the Python module, "
                        "but we failed to compile it to a TorchScript function. "
                        "\nThe error stack is reproduced here:\n{}").format(e)
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name,
                torch._C._jit_try_infer_type(value).type(),
                value)
            continue

        # If we got here, this is a regular "data" attribute; add it to the concrete type
        attr_type, inferred = infer_type(name, value)
        if attr_type.success():
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            inferred_msg = "Its type was inferred; try adding a type annotation for the attribute." if inferred else ""
            additional_info = f"{attr_type.reason()}. {inferred_msg}"
            hint = "(This attribute exists on the Python module, " \
                   f"but we failed to convert Python type: '{torch.typename(type(value))}' " \
                   f"to a TorchScript type. {additional_info})"
            concrete_type_builder.add_failed_attribute(name, hint)

    # add hooks to concrete type
    for hook in nn_module._forward_hooks.values():
        concrete_type_builder.add_forward_hook(hook)
    for pre_hook in nn_module._forward_pre_hooks.values():
        concrete_type_builder.add_forward_pre_hook(pre_hook)

    return concrete_type_builder


class ConcreteTypeStore:
    type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
    methods_compiled: Set[torch._C.ConcreteModuleType]

    def __init__(self):
        # Python module type => List[ConcreteModuleType]
        self.type_store = {}
        # ConcreteTypes that have had their methods already compiled
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """
        Infer a ConcreteType from this `nn.Module` instance. Underlying JIT
        types are re-used if possible.
        """
        concrete_type_builder = infer_concrete_type_builder(nn_module)

        nn_module_type = type(nn_module)
        if nn_module_type not in self.type_store:
            self.type_store[nn_module_type] = []

        # Search the type store for an already-available JIT type
        known_types = self.type_store[nn_module_type]
        for known_type in known_types:
            if known_type.equals(concrete_type_builder):
                return known_type

        # We didn't find anything; generate a new JIT type from this concrete type
        concrete_type = concrete_type_builder.build()
        self.type_store[nn_module_type].append(concrete_type)
        return concrete_type
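
# Illustrative sketch of the sharing behavior (hypothetical module class `M`
# with a single `hidden` attribute set in __init__): two instances whose
# inferred builders compare equal get the same cached JIT type, while an
# instance that differs (e.g. an attribute inferred as a different type) gets
# a fresh one:
#
#   t1 = concrete_type_store.get_or_create_concrete_type(M(hidden=1))
#   t2 = concrete_type_store.get_or_create_concrete_type(M(hidden=2))
#   t3 = concrete_type_store.get_or_create_concrete_type(M(hidden="a"))
#   t1 is t2  # True, assuming only the int value differs
#   t1 is t3  # False: `hidden` is inferred as str, so the builders differ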

concrete_type_store = ConcreteTypeStore()


def create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs):
    method_defs = [m.def_ for m in method_stubs]
    method_rcbs = [m.resolution_callback for m in method_stubs]
    method_defaults = [get_default_args(m.original_method) for m in method_stubs]

    property_defs = [p.def_ for p in property_stubs]
    property_rcbs = [p.resolution_callback for p in property_stubs]

    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)


def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
    hook_defs = [h.def_ for h in hook_stubs]
    hook_rcbs = [h.resolution_callback for h in hook_stubs]

    pre_hook_defs = [h.def_ for h in pre_hook_stubs]
    pre_hook_rcbs = [h.resolution_callback for h in pre_hook_stubs]

    concrete_type._create_hooks(hook_defs, hook_rcbs, pre_hook_defs, pre_hook_rcbs)


def get_module_concrete_type(nn_module, share_types=True):
    """
    Gets a concrete type for the given nn_module. If share_types is True, the
    concrete type is fetched from concrete_type_store. If it is False, a new
    concrete type is created without first searching concrete_type_store.

    Args:
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
        share_types: Whether to share underlying JIT types between modules (if possible).

    Returns:
        A concrete type for nn_module.
    """
    assert isinstance(nn_module, Module)
    if isinstance(nn_module, torch.jit.ScriptModule) and \
            hasattr(nn_module, "_concrete_type"):
        return nn_module._concrete_type

    if share_types:
        # Look into the store of cached JIT types
        concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
    else:
        # Get a concrete type directly, without trying to re-use an existing JIT
        # type from the type store.
        concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
        concrete_type_builder.set_poisoned()
        concrete_type = concrete_type_builder.build()

    return concrete_type


def create_script_class(obj):
    """
    Create and return a RecursiveScriptClass instance from a Python object.

    Args:
        obj: A Python object.
    """
    qualified_class_name = _jit_internal._qualified_name(type(obj))
    rcb = _jit_internal.createResolutionCallbackForClassMethods(type(obj))
    # Script the type of obj if it hasn't already been scripted.
    _compile_and_register_class(type(obj), rcb, qualified_class_name)
    class_ty = _python_cu.get_class(qualified_class_name)
    # Create an empty torch._C.ScriptObject with the scripted type.
    cpp_object = torch._C._create_object_with_type(class_ty)
    # Copy all of the attributes over to the torch._C.ScriptObject.
    for name, value in obj.__dict__.items():
        cpp_object.setattr(name, value)

    # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance.
    return wrap_cpp_class(cpp_object)
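
# A usage sketch, assuming a user-defined class `Point` (hypothetical). This
# is roughly the path taken when torch.jit.script is handed a class instance:
#
#   class Point:
#       def __init__(self, x: int, y: int):
#           self.x = x
#           self.y = y
#
#   scripted = create_script_class(Point(1, 2))
#   # `scripted` is a RecursiveScriptClass wrapping a torch._C.ScriptObject
#   # whose attributes were copied from the Python instance.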


def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False):
    """
    Creates a new ScriptModule from an nn.Module.

    Args:
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
        stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
        share_types: Whether to share underlying JIT types between modules (if possible).
            NOTE: Only set this to False when we cannot guarantee type sharing will work
                correctly. This only happens today for traced modules, where the same
                module can produce different traced methods depending on the inputs.
        is_tracing: Whether this function is called during tracing or scripting. If tracing,
                we don't need to run the AttributeTypeIsSupportedChecker because all the
                unsupported attributes will be baked as constants in the traced graph.
                In addition, this check significantly slows down tracing of large modules.
    """
    assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    check_module_initialized(nn_module)
    concrete_type = get_module_concrete_type(nn_module, share_types)
    if not is_tracing:
        AttributeTypeIsSupportedChecker().check(nn_module)
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
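
# A usage sketch (hypothetical `MyModule`). User code does not call this
# directly; torch.jit.script on an nn.Module instance routes here with
# `infer_methods_to_compile` as the stubs_fn:
#
#   class MyModule(torch.nn.Module):
#       def forward(self, x):
#           return x + 1
#
#   sm = torch.jit.script(MyModule())
#   # internally roughly equivalent to:
#   # create_script_module(MyModule(), infer_methods_to_compile)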


def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Args:
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type: The fully initialized ConcreteType of the module.
        stubs_fn: Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)
    property_stubs = get_property_stubs(nn_module)
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)

    user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list())
    ignored_properties = jit_ignored_properties(nn_module)

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name, (attr_type, is_param) in concrete_type.get_attributes().items():
            orig_value = getattr(nn_module, name)
            orig_value = orig_value.value if isinstance(orig_value, torch.jit.Attribute) else orig_value
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(orig_value, Module), "Expected Module but got {}".format(type(orig_value))
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                scripted = orig_value
            else:
                # always reuse the provided stubs_fn to infer the methods to compile
                scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)

            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            if name in ignored_properties:
                continue
            item = getattr(nn_module, name, None)
            if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
                unbound_function = getattr(nn_module, name).__func__
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)
            elif concrete_type.is_ignored_attribute(name):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Compile methods if necessary
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
        # Create hooks after methods to ensure no name collisions between hooks and methods.
        # If done before, hooks can overshadow methods that aren't exported.
        create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Copy the forward hooks and pre-hooks to the new ScriptModule
    # to allow the hooks to be run from eager as ScriptFunctions
    for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
        script_module._forward_pre_hooks[idx] = fn
    for idx, fn in enumerate(script_module._c._get_forward_hooks()):
        script_module._forward_hooks[idx] = fn

    # Special handling so methods like __len__ work in script methods on classes derived from containers
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \
            '__len__' not in cpp_module._method_names():
        script_module.define("def __len__(self):\n   return {}\n".format(len(nn_module)))
    if isinstance(nn_module, torch.nn.ModuleDict) and \
            '__contains__' not in cpp_module._method_names():
        if len(nn_module.keys()):
            keys = repr(list(nn_module.keys()))
            script_module.define("def __contains__(self, key: str):\n   return key in {}\n".format(keys))
        else:
            script_module.define("def __contains__(self, key: str):\n   return False\n")

    # Make the compiled methods available to the Python ScriptModule class.
    for method_stub in method_stubs:
        if method_stub.original_method is None:
            # define()'d methods don't have a Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = method_stub.original_method.__name__
        if name != method_stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this for functions that are recursively
        # compiled, but we should.
        wrapped_script_method = functools.wraps(method_stub.original_method)(script_method)

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = wrapped_script_method

    # Make module properties available on the Python ScriptModule class.
    for property_stub in property_stubs:
        property_name = property_stub.def_.name().name
        fget = cpp_module._get_method(property_stub.def_.getter_name().name)
        # Setter is optional, so it may not exist.
        setter_name = property_stub.def_.setter_name()
        fset = cpp_module._get_method(setter_name.name) if setter_name else None
        script_module.__dict__[property_name] = property(fget, fset)

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER:
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module


# We define shims of certain attributes on the RecursiveScriptModule to support
# magic methods. To check if a script model defines an attribute we need
# to also check that the attribute is not the shim.
def script_model_defines_attr(script_model, attr):
    script_attr = getattr(script_model, attr, None)
    if script_attr is None:
        return False
    default_attr = getattr(torch.jit.RecursiveScriptModule, attr, None)
    if default_attr is None:
        return False
    return script_attr != default_attr


def add_python_attr_to_scripted_model(script_model, orig, attr):
    if hasattr(orig, attr) and script_model_defines_attr(script_model, attr):
        setattr(script_model, attr, getattr(orig, attr))


def get_overload_annotations(mod, jit_ignored_properties):
    # original function => [(mangled overload name, overload function)]
    overloads = {}

    for name in dir(type(mod)):
        if name in jit_ignored_properties:
            continue
        item = getattr(mod, name, None)
        if not callable(item):
            continue

        # builtin functions like repr() in python 2 do not have __module__ defined
        if hasattr(item, "__module__") and item.__module__ is not None:
            method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__)
            if method_overloads is None:
                continue

            if item.__func__ in method_overloads:
                raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
                    'method', item.__func__))

            names = [name + "__" + str(i) for i in range(len(method_overloads))]
            overloads[item] = list(zip(names, method_overloads))

    return overloads
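
# Illustrative sketch of the shape this produces (hypothetical module using the
# internal @torch.jit._overload_method decorator; the overload declarations
# have empty bodies and the final definition is the implementation):
#
#   class M(torch.nn.Module):
#       @torch.jit._overload_method
#       def forward(self, x: int) -> int: ...
#       @torch.jit._overload_method
#       def forward(self, x: str) -> str: ...
#       def forward(self, x):
#           return x
#
#   get_overload_annotations(M(), set())
#   # -> {<bound M.forward>: [("forward__0", <decl 0>), ("forward__1", <decl 1>)]}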


def get_overload_name_mapping(overload_info):
    # Same format as __overloads__:
    # original function name => [overload names]
    overload_name_mappings: Dict[str, List[str]] = {}
    for orig_fn, overloads in overload_info.items():
        original_name = orig_fn.__name__
        if original_name not in overload_name_mappings:
            overload_name_mappings[original_name] = []

        for overload_name, _ in overloads:
            overload_name_mappings[original_name].append(overload_name)
    return overload_name_mappings


def _check_no_signature(func):
    signature = torch.jit.annotations.get_signature(func, None, fake_range(), inspect.ismethod(func))
    if signature is None:
        qual_name = _jit_internal._qualified_name(func)
        raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))


def make_stubs_for_overloads(overload_info):
    overload_stubs = []
    for orig_fn, overloads in overload_info.items():
        orig_ast = get_jit_def(orig_fn, orig_fn.__name__, self_name="RecursiveScriptModule")
        for overload_name, overload_fn in overloads:
            _check_no_signature(overload_fn)
            over_ast = get_jit_def(overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule")
            new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, overload_name)
            _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(ScriptMethodStub(_rcb, new_ast, overload_fn))
    return overload_stubs


def check_module_initialized(mod):
    assert isinstance(mod, torch.nn.Module)
    if not hasattr(mod, '_parameters'):
        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                           .format(torch.typename(type(mod))))

    # This is to avoid importing torch.distributed.nn
    if not hasattr(mod, 'remote_parameters'):
        for name, param in mod._parameters.items():
            if param is not None and torch.nn.parameter.is_lazy(param):
                raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?"
                                   .format(torch.typename(type(mod)), name))
        for name, buf in mod._buffers.items():
            if buf is not None and torch.nn.parameter.is_lazy(buf):
                raise RuntimeError("'{}' has uninitialized buffers {}. Did you forget to run a forward pass?"
                                   .format(torch.typename(type(mod)), name))
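
# Illustrative failure mode this guards against (hypothetical module):
#
#   class Broken(torch.nn.Module):
#       def __init__(self):
#           pass  # forgot super().__init__(), so `_parameters` never exists
#
#   torch.jit.script(Broken())  # -> RuntimeError: ... has not been initialized,
#                               #    did you forget to call 'super()'?
#
# Lazy modules (e.g. torch.nn.LazyLinear) must see at least one forward pass so
# their parameters are materialized before scripting.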


def infer_methods_to_compile(nn_module):
    """
    Implements the default rules for which methods should act as starting
    points for compilation (TODO add a link when the rules are published).
    """
    check_module_initialized(nn_module)
    user_annotated_ignored_attributes = getattr(nn_module, "__jit_ignored_attributes__", list())
    ignored_properties = jit_ignored_properties(nn_module)

    methods: List[str] = []
    if hasattr(nn_module, 'forward') and not _jit_internal.is_ignored_fn(nn_module.forward):
        forward_func = getattr(nn_module.forward, "__func__", None)
        module_forward = getattr(torch.nn.Module, "forward", None)
        if forward_func != module_forward:
            methods = ['forward']

    exported = []
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
            exported.append(name)

    methods = methods + exported

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # we shouldn't directly compile overloaded methods, just their overloads
    def ignore_overloaded(method_name):
        return method_name not in overload_name_mappings

    filtered_methods = filter(ignore_overloaded, methods)

    # Unique the methods. We don't want to use a set to store the methods because it
    # introduces non-determinism to compile order.
    uniquer: Set[str] = set()
    uniqued_methods = []
    for name in filtered_methods:
        if name in uniquer:
            continue
        uniqued_methods.append(name)
        uniquer.add(name)

    stubs = []
    for method in uniqued_methods:
        stubs.append(make_stub_from_method(nn_module, method))
    return overload_stubs + stubs
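
# Illustrative sketch of the default rules (hypothetical module): `forward` is
# compiled if defined and not @ignore'd, @torch.jit.export methods are added as
# extra compilation roots, and other methods are compiled only if reachable:
#
#   class M(torch.nn.Module):
#       def forward(self, x):
#           return self.helper(x)
#
#       @torch.jit.export
#       def extra_entry(self, x):
#           return x * 2
#
#       def helper(self, x):  # compiled because `forward` calls it
#           return x + 1
#
#   # infer_methods_to_compile(M()) yields stubs for 'forward' and 'extra_entry'.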


def get_hook_stubs(nn_module):
    """
    Returns forward hook and pre_hook ScriptMethodStubs.
    """
    check_module_initialized(nn_module)
    hook_map: Dict = {}

    hook_stubs = []
    for hook in nn_module._forward_hooks.values():
        if hook.__name__ in hook_map:
            if id(hook) != id(hook_map[hook.__name__]):
                raise RuntimeError(
                    f"Hook '{hook.__name__}' on {type(nn_module).__name__} "
                    "has at least two different python definitions."
                    " Please use unique names for all hooks."
                )
        else:
            hook_map[hook.__name__] = hook
        hook_stubs.append(make_stub(hook, hook.__name__))

    pre_hook_stubs = []
    for pre_hook in nn_module._forward_pre_hooks.values():
        if pre_hook.__name__ in hook_map:
            if id(pre_hook) != id(hook_map[pre_hook.__name__]):
                raise RuntimeError(
                    f"Pre-hook '{pre_hook.__name__}' on {type(nn_module).__name__} "
                    "has at least two different python definitions."
                    " Please use unique names for all hooks."
                )
        else:
            hook_map[pre_hook.__name__] = pre_hook
        pre_hook_stubs.append(make_stub(pre_hook, pre_hook.__name__))

    return hook_stubs, pre_hook_stubs
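
# Illustrative sketch (hypothetical hook and module): hooks registered on the
# eager module before scripting are compiled alongside its methods, which is
# why each hook needs a unique __name__:
#
#   def log_output(module, inputs: Tuple[torch.Tensor], output: torch.Tensor):
#       return output
#
#   m = MyModule()  # hypothetical nn.Module
#   m.register_forward_hook(log_output)
#   sm = torch.jit.script(m)  # get_hook_stubs(m) produces a stub for log_output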


def get_property_stubs(nn_module):
    """
    Create property stubs for the properties of the module by creating method
    stubs for the getter and setter.
    """
    module_ty = type(nn_module)
    properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
    rcbs = {}

    for name in dir(module_ty):
        item = getattr(module_ty, name, None)
        if isinstance(item, property):
            if not item.fget:
                raise RuntimeError(f'Property {name} of {module_ty.__name__} must have a getter')

            rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)

    stubs = [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
    return stubs
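
# Illustrative sketch (hypothetical module): each property needs at least a
# getter; the getter (and optional setter) are compiled as methods, and the
# property is re-exposed on the Python ScriptModule in create_script_module_impl:
#
#   class M(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self._scale = 2.0
#
#       @property
#       def scale(self) -> float:
#           return self._scale
#
#       @scale.setter
#       def scale(self, value: float):
#           self._scale = value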


def interface_script(mod_interface, nn_module):
    """
    Makes a ScriptModule from an nn.Module, using the interface methods rule
    for determining which methods to compile.

    Args:
        mod_interface: the interface type that the module implements
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
    """
    if isinstance(nn_module, torch.jit.ScriptModule):
        return nn_module

    check_module_initialized(nn_module)

    def infer_interface_methods_to_compile(nn_module):
        """
        Rule to infer the methods from the interface type to know which
        methods need to act as starting points for compilation.
        """
        stubs = []
        for method in mod_interface.getMethodNames():
            stubs.append(make_stub_from_method(nn_module, method))
        return stubs

    return create_script_module(nn_module, infer_interface_methods_to_compile)
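
# Illustrative sketch (hypothetical interface): when a submodule is annotated
# with a @torch.jit.interface type, only the interface's methods are compiled,
# regardless of what else the concrete module defines:
#
#   @torch.jit.interface
#   class OpInterface(torch.nn.Module):
#       def forward(self, x: torch.Tensor) -> torch.Tensor:
#           pass
#
#   class Impl(torch.nn.Module):
#       def forward(self, x: torch.Tensor) -> torch.Tensor:
#           return x + 1
#
#       def not_compiled(self):  # not part of the interface; skipped
#           return "python only"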


def try_compile_fn(fn, loc):
    if _jit_internal.is_ignored_fn(fn):
        # Don't do anything for @ignore'd functions
        return None

    if isinstance(fn, torch.nn.Module):
        # Since modules are callable, pybind recognizes them as functions, but
        # don't do anything for them
        return None

    if not inspect.isfunction(fn) and not inspect.ismethod(fn):
        raise RuntimeError("`{}` is not a function. Recursive scripting only supports "
                           "Python functions or methods currently.\n"
                           "Consider manually annotating `{}` with @torch.jit.script.".format(fn, fn))

    # We don't have the actual scope where the function was defined, but we can
    # extract the necessary info from the closed over variables on the function
    # object
    rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=rcb)


def wrap_cpp_class(cpp_class):
    """
    Wrap this torch._C.Object in a Python RecursiveScriptClass.
    """
    return torch.jit.RecursiveScriptClass(cpp_class)


def wrap_cpp_module(cpp_module):
    """
    Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules.
    """
    def init_fn(script_module):
        for name, sub_cpp_module in torch._C.ModuleDict(script_module._c).items():
            setattr(script_module, name, wrap_cpp_module(sub_cpp_module))
        script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type())

        for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
            script_module._forward_pre_hooks[idx] = fn
        for idx, fn in enumerate(script_module._c._get_forward_hooks()):
            script_module._forward_hooks[idx] = fn

    return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)


def compile_unbound_method(concrete_type, fn):
    if _jit_internal.is_ignored_fn(fn):
        return None
    stub = make_stub(fn, fn.__name__)
    with torch._jit_internal._disable_emit_hooks():
        # We don't want to call the hooks here since the graph that is calling
        # this function is not yet complete
        create_methods_and_properties_from_stubs(concrete_type, (stub,), ())
    return stub


def lazy_bind(concrete_type, unbound_method):
    """
    Returns a function that lazily binds `unbound_method` to a provided
    Module IValue, then invokes the method. We do this so that any Python
    shenanigans that will poison type sharing are impossible at compile
    time.
    """
    def lazy_binding_method(cpp_module, *args):
        def init_fn(script_module):
            orig_class = concrete_type.py_class

            # Copy @ignored/@unused methods from the original module to the new one.
            # This ensures they are available during execution.
            for name in dir(orig_class):
                item = getattr(orig_class, name, None)
                if _jit_internal.is_ignored_fn(item):
                    setattr(script_module, name, item)

            # Copy constants over so they are available during execution.
            for name, value in concrete_type.get_constants().items():
                setattr(script_module, name, value)

        script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
        method = types.MethodType(unbound_method, script_module)
        return method(*args)

    # make the lazy binding method "look like" the original method
    lazy_binding_method.original_fn = unbound_method  # type: ignore[attr-defined]
    lazy_binding_method.__name__ = unbound_method.__name__
    torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)

    return lazy_binding_method