- """Tracing
- This module contains functionality to support the JIT's tracing frontend, notably:
- * torch.jit.trace
- * torch.jit.trace_module
- This is not intended to be imported directly; please use the exposed
- functionalities in `torch.jit`.
- """
- import torch
- import copy
- import os
- import contextlib
- import functools
- import warnings
- import inspect
- import re
- from typing import Any, Callable, Dict, List, Optional, Set
- from torch.jit._state import _python_cu, _enabled
- from torch.jit._script import ScriptModule, _CachedForward, script
- from torch._jit_internal import _qualified_name, is_scripting, get_callable_argument_names
- from torch.autograd import function
- from torch.nn import Module
- from torch.testing._comparison import default_tolerances
- _flatten = torch._C._jit_flatten
- _unflatten = torch._C._jit_unflatten
- def _create_interpreter_name_lookup_fn(frames_up=1):
- def _get_interpreter_name_for_var(var):
- frame = inspect.currentframe()
- if not frame:
- raise RuntimeError("failed to inspect frame")
- i = 0
- while i < frames_up + 1:
- frame = frame.f_back
- if not frame:
- raise RuntimeError("failed to get frame")
- i += 1
- f_locals = frame.f_locals
- f_globals = frame.f_globals
- for k, v in f_locals.items():
- if isinstance(v, torch.Tensor) and var is v:
- return k if k != "self" else ""
- return ""
- return _get_interpreter_name_for_var
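- # Usage sketch (illustrative only; `lookup` and `x` are assumptions): the
- # returned function maps a traced tensor back to the variable name it has in
- # the caller's frame, which gives debug names to values in the traced graph.
- #
- #   lookup = _create_interpreter_name_lookup_fn(frames_up=0)
- #   x = torch.rand(2)
- #   lookup(x)   # -> "x"; returns "" for `self` or unnamed tensors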
- def _unique_state_dict(module, keep_vars=False):
- # Parameter.detach() always creates a new torch.Tensor instance, so id(v)
- # is not stable when values are detached eagerly. We therefore fetch the
- # Parameter/Buffer objects themselves (keep_vars=True) and deduplicate on
- # their ids before optionally detaching below.
- state_dict = module.state_dict(keep_vars=True)
- filtered_dict = type(state_dict)()
- seen_ids: Set[int] = set()
- for k, v in state_dict.items():
- if id(v) in seen_ids:
- continue
- seen_ids.add(id(v))
- if keep_vars:
- filtered_dict[k] = v
- else:
- filtered_dict[k] = v.detach()
- return filtered_dict
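- # Sketch of the deduplication behavior (illustrative; `m` is a made-up module):
- #
- #   m = torch.nn.Linear(2, 2)
- #   m.alias = m.weight                # same Parameter registered under two names
- #   list(_unique_state_dict(m))       # -> ["weight", "bias"]; "alias" is
- #                                     #    dropped because its id was seen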
- class ONNXTracedModule(torch.nn.Module):
- def __init__(
- self,
- inner,
- strict=True,
- force_outplace=False,
- return_inputs=False,
- return_inputs_states=False,
- ):
- super().__init__()
- # inner may be a Module, or it may be an arbitrary callable
- # If it's a Module, we get its parameters automatically, which lets
- # us avoid special-casing functions versus modules.
- self.inner = inner
- self.strict = strict
- self._force_outplace = force_outplace
- self._return_inputs = return_inputs
- self._return_inputs_states = return_inputs_states
- def forward(self, *args: torch.Tensor):
- in_vars, in_desc = _flatten(args)
- # NOTE: use full state, because we need it for BatchNorm export
- # This differs from the compiler path, which doesn't support it at the moment.
- module_state = list(_unique_state_dict(self, keep_vars=True).values())
- ret_inputs = []
- inputs_states = []
- outs = []
- def wrapper(*args):
- in_args: List[torch.Tensor] = []
- for i in range(len(in_vars)):
- if not isinstance(args[i], torch.Tensor):
- raise RuntimeError('Expected Tensor argument')
- in_args.append(args[i])
- trace_inputs = _unflatten(in_args, in_desc)
- ret_inputs.append(
- tuple(x.clone(memory_format=torch.preserve_format) for x in args)
- )
- if self._return_inputs_states:
- inputs_states.append(_unflatten(in_args, in_desc))
- outs.append(self.inner(*trace_inputs))
- if self._return_inputs_states:
- inputs_states[0] = (inputs_states[0], trace_inputs)
- out_vars, _ = _flatten(outs)
- if len(out_vars) == 1:
- return out_vars[0]
- else:
- return tuple(out_vars)
- graph, out = torch._C._create_graph_by_tracing(
- wrapper,
- in_vars + module_state,
- _create_interpreter_name_lookup_fn(),
- self.strict,
- self._force_outplace,
- )
- if self._return_inputs:
- return graph, outs[0], ret_inputs[0]
- if self._return_inputs_states:
- return graph, outs[0], inputs_states[0]
- else:
- return graph, outs[0]
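- # Sketch of how the ONNX export path drives this wrapper (illustrative only;
- # `model` and `example` are assumptions):
- #
- #   traced = ONNXTracedModule(model, return_inputs=True)
- #   graph, out, recorded_inputs = traced(example)
- #   # `graph` is the traced IR, `out` the eager result, and `recorded_inputs`
- #   # holds clones of the tensors fed to the trace, captured before the run.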
- def _clone_inputs(args):
- def clone_input(a):
- if a is None:
- return None
- elif isinstance(a, torch.Tensor):
- # TODO: figure out one liner to .clone() and set requires_grad
- v = (
- a.detach()
- .clone(memory_format=None if a.is_mkldnn else torch.preserve_format)
- .requires_grad_(a.requires_grad)
- )
- if a.grad is not None:
- v.grad = clone_input(v.grad)
- return v
- else:
- return a.clone(memory_format=torch.preserve_format)
- return function._nested_map(
- lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
- )(args)
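- # Sketch (illustrative): cloning preserves values, gradients, and
- # requires_grad, but yields fresh tensors the traced run cannot mutate.
- #
- #   a = torch.rand(3, requires_grad=True)
- #   (b,) = _clone_inputs((a,))
- #   assert b.requires_grad and b is not a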
- # This is purely for developer debugging. We are not going to advertise it.
- _JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False) # CUDA-only timing
- _JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False)
- _JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False)
- @contextlib.contextmanager
- def _time(trace_name, name, time=True):
- if (not _JIT_TIME and not time) or not torch.cuda.is_available():
- yield
- return
- stream = torch.cuda.current_stream()
- start = torch.cuda.Event(enable_timing=True)
- end = torch.cuda.Event(enable_timing=True)
- stream.record_event(start)
- try:
- yield
- finally:
- stream.record_event(end)
- end.synchronize()
- print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
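- # Usage sketch (CUDA only; `model` and `inp` are assumptions): with
- # PYTORCH_JIT_TIME=1 set, or time=True, this prints elapsed GPU time.
- #
- #   with _time("my_trace", "forward"):
- #       out = model(inp)   # prints "my_trace forward time: ... ms"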
- def verify(model, args, loss_fn=torch.sum, devices=None):
- """
- Verify that a JIT compiled model has the same behavior as its uncompiled
- version along with its backwards pass. If your model returns multiple
- outputs, you must also specify a `loss_fn` to produce a loss for which
- the backwards will be computed.
- This function has side-effects (e.g., it executes your model / saves and loads
- parameters), so don't expect the model to come out exactly the same as what
- you passed in.
- Args:
- model (compiled torch.nn.Module or function): the module/function to be
- verified. The module/function definition MUST have been decorated with
- `@torch.jit.compile`.
- args (tuple or Tensor): the positional arguments to pass to the
- compiled function/module to be verified. A non-tuple is assumed to
- be a single positional argument to be passed to the model.
- loss_fn (function, optional): the loss function to be applied to
- the output of the model, before backwards is invoked. By default,
- we assume that a model returns a single result, and we :func:`torch.sum`
- before calling backwards; if this is inappropriate, you can pass your
- own loss function. Note that if a model returns a tuple of results,
- these are passed as separate positional arguments to `loss_fn`.
- devices (iterable of device IDs, optional): the GPU devices which the
- compiled module will be run on. This determines the RNG state we
- must save when running both compiled and uncompiled versions of the model.
- """
- # TODO: In principle, we track device information in our trace, so it
- # should be possible to check if our execution actually obeyed the 'devices'
- # the user provided.
- # TODO: Consider adding a utility function to torch.jit to test
- # for this case
- if not isinstance(model, torch._C.CompiledFunction): # type: ignore[attr-defined]
- raise TypeError(
- "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it"
- )
- is_module = isinstance(model, Module)
- if not isinstance(args, tuple):
- args = (args,)
- saved_args = _clone_inputs(args)
- if is_module:
- saved_state = copy.deepcopy(model.state_dict())
- def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
- params = list(model.parameters()) if is_module else []
- in_vars, _ = _flatten((args, params))
- # We use a special API to reset the trace and compile it from scratch.
- compiled_fn = model
- if force_trace:
- compiled_fn.clear_cache()
- if assert_compiled:
- hits = compiled_fn.hits
- out = model(*args)
- if assert_compiled and compiled_fn.hits == hits:
- raise RuntimeError("failed to use the compiled function")
- if not isinstance(out, tuple):
- out = (out,)
- if loss_fn == torch.sum and len(out) != 1:
- raise ValueError(
- (
- "Model returns {} outputs, but default loss function "
- "(torch.sum) can only handle a single output"
- ).format(len(out))
- )
- out_vars, _ = _flatten(out)
- saved_outs = [
- v.detach().clone(memory_format=torch.preserve_format) for v in out_vars
- ]
- loss = loss_fn(*out)
- grads = torch.autograd.grad([loss], in_vars)
- # TODO: I'm not sure if the clone here is necessary but it is safer
- saved_grads = [
- v.detach().clone(memory_format=torch.preserve_format) for v in grads
- ]
- return (saved_outs, saved_grads)
- with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
- uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
- assert model.has_trace_for(*args)
- if is_module:
- model.load_state_dict(saved_state)
- compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)
- _verify_equal(uncompiled_outs, compiled_outs)
- _verify_equal(uncompiled_grads, compiled_grads)
- def _verify_equal(xs, ys):
- for x, y in zip(xs, ys):
- if x.sub(y).abs().max() > 1e-6:
- raise RuntimeError("JIT and real computation mismatch")
- def indent(s):
- return "\n".join(["\t" + line for line in s.splitlines()])
- class TracingCheckError(Exception):
- def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
- self.message = "Tracing failed sanity checks!\n"
- if extra_msg is not None:
- self.message += extra_msg + "\n"
- if graph_diff_error is not None:
- self.message += "ERROR: Graphs differed across invocations!\n"
- self.message += indent(graph_diff_error) + "\n"
- if tensor_compare_error is not None:
- self.message += (
- "ERROR: Tensor-valued Constant nodes differed in value "
- "across invocations. This often indicates that the tracer has"
- " encountered untraceable code.\n"
- )
- self.message += indent(tensor_compare_error) + "\n"
- super().__init__(self.message)
- # Check the traced module against a set of user-provided validation inputs
- @torch.no_grad()
- def _check_trace(
- check_inputs,
- func,
- traced_func,
- check_tolerance,
- strict,
- force_outplace,
- is_trace_module,
- _module_class,
- example_inputs_is_kwarg=False,
- ):
- # Note: tracing is independent of optimizations, which consume the trace
- for inputs in check_inputs:
- if isinstance(inputs, torch.Tensor):
- inputs = (inputs,)
- if is_trace_module:
- copied_dict = {}
- for name, data in inputs.items():
- copied_dict[name] = _clone_inputs(data)
- check_mod = torch.jit.trace_module(
- func.__self__ if hasattr(func, "__self__") else func,
- copied_dict,
- check_trace=False,
- strict=strict,
- _force_outplace=force_outplace,
- _module_class=_module_class,
- _compilation_unit=torch._C.CompilationUnit(),
- example_inputs_is_kwarg=example_inputs_is_kwarg,
- _store_inputs=False
- )
- check_mod_func = check_mod._c._get_method(traced_func.name)
- inputs = inputs[traced_func.name]
- if isinstance(inputs, torch.Tensor) or (isinstance(inputs, dict) and not example_inputs_is_kwarg):
- inputs = (inputs,)
- else:
- if example_inputs_is_kwarg:
- check_mod = torch.jit.trace(
- func,
- check_trace=False,
- strict=strict,
- _force_outplace=force_outplace,
- _module_class=_module_class,
- example_kwarg_inputs=_clone_inputs(inputs),
- _store_inputs=False
- )
- else:
- check_mod = torch.jit.trace(
- func,
- _clone_inputs(inputs),
- check_trace=False,
- strict=strict,
- _force_outplace=force_outplace,
- _module_class=_module_class,
- _store_inputs=False
- )
- check_mod_func = check_mod
- def graph_diagnostic_info():
- mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph)
- torch._C._jit_pass_inline(mod_canonicalized)
- torch._C._jit_pass_erase_shape_information(mod_canonicalized)
- mod_str = str(mod_canonicalized)
- mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str)
- check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph)
- torch._C._jit_pass_inline(check_canonicalized)
- torch._C._jit_pass_erase_shape_information(check_canonicalized)
- check_str = str(check_canonicalized)
- check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str)
- graph_diff_errors = None
- if mod_str != check_str:
- import difflib
- graph_diff = difflib.ndiff(
- mod_str.splitlines(True), check_str.splitlines(True)
- )
- graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n"
- for n_mod, n_check in zip(
- mod_canonicalized.nodes(), check_canonicalized.nodes()
- ):
- if str(n_mod) != str(n_check):
- graph_diff_errors += "First diverging operator:\n"
- node_diff = difflib.ndiff(
- str(n_mod).splitlines(True), str(n_check).splitlines(True)
- )
- source_printout = (
- "Node diff:\n" + indent("".join(node_diff)) + "\n"
- )
- mod_stack = n_mod.sourceRange()
- if mod_stack:
- source_printout += (
- "Trace source location:\n" + indent(mod_stack) + "\n"
- )
- check_stack = n_check.sourceRange()
- if check_stack:
- source_printout += (
- "Check source location:\n" + indent(check_stack) + "\n"
- )
- graph_diff_errors += source_printout
- break # For now, only print out the first pair of nodes that diverges
- tensor_compare_errors = None
- # Check Tensor-valued constant nodes
- for n_mod, n_check in zip(
- mod_canonicalized.nodes(), check_canonicalized.nodes()
- ):
- if n_mod.kind() != n_check.kind():
- break # Graphs have already diverged
- if n_mod.kind() == "prim::Constant" and not (
- n_mod.mustBeNone() or n_check.mustBeNone()
- ):
- if not n_mod.hasAttribute("value"):
- continue
- if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t":
- continue
- mod_tensor_val = n_mod.t("value")
- check_tensor_val = n_check.t("value")
- try:
- torch.testing.assert_close(mod_tensor_val, check_tensor_val, equal_nan=True)
- except (RuntimeError, AssertionError) as e:
- if tensor_compare_errors is None:
- tensor_compare_errors = ""
- tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n"
- compare_stack = n_mod.sourceRange()
- if compare_stack:
- tensor_compare_errors += (
- "Source Location:\n" + indent(compare_stack) + "\n"
- )
- tensor_compare_errors += "Comparison exception: " + indent(
- str(e)
- )
- break # For now, only print the first diverging pair
- return graph_diff_errors, tensor_compare_errors
- def wrap_retval(x):
- return x if isinstance(x, tuple) else (x,)
- def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
- try:
- if isinstance(inputs, dict) and example_inputs_is_kwarg:
- outs = wrap_retval(mod(**inputs))
- else:
- outs = wrap_retval(mod(*_clone_inputs(inputs)))
- outs = [out for out in outs if isinstance(out, torch.Tensor)]
- return outs
- except Exception as e:
- graph_diff_errors, tensor_compare_errors = graph_diagnostic_info()
- msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}"
- raise TracingCheckError(
- graph_diff_errors,
- tensor_compare_errors,
- extra_msg=msg,
- ) from e
- has_warned = [False]
- def maybe_warn_nondeterministic():
- if has_warned[0]:
- return
- has_warned[0] = True
- nondeterm_ops = [
- op for op in traced_func.graph.nodes() if op.isNondeterministic()
- ]
- if len(nondeterm_ops) > 0:
- nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
- nondeterministic_ops_warning += (
- "Did you forget call .eval() on your model? Nodes:\n"
- )
- nondeterministic_ops_warning += "\n".join(
- [indent(str(op)) for op in nondeterm_ops][:20]
- )
- nondeterministic_ops_warning += (
- "\nThis may cause errors in trace checking. To disable trace checking,"
- " pass check_trace=False to torch.jit.trace()"
- )
- warnings.warn(
- nondeterministic_ops_warning, category=TracerWarning, stacklevel=5
- )
- def compare_outputs(original, reference, match_what):
- all_ok = True
- for i, (orig, ref) in enumerate(zip(original, reference)):
- try:
- if orig.is_quantized:
- orig = orig.dequantize()
- if ref.is_quantized:
- ref = ref.dequantize()
- if orig.is_mkldnn:
- orig = orig.to_dense()
- if ref.is_mkldnn:
- ref = ref.to_dense()
- if ref.is_complex() or orig.is_complex():
- torch.testing.assert_close(
- orig.to(torch.cdouble),
- ref.to(torch.cdouble),
- rtol=check_tolerance,
- atol=default_tolerances(orig, ref)[1],
- equal_nan=True,
- )
- else:
- if orig.is_mps or ref.is_mps:
- torch.testing.assert_close(
- orig.float(),
- ref.float(),
- rtol=check_tolerance,
- atol=default_tolerances(orig, ref)[1],
- equal_nan=True,
- )
- else:
- torch.testing.assert_close(
- orig.double(),
- ref.double(),
- rtol=check_tolerance,
- atol=default_tolerances(orig, ref)[1],
- equal_nan=True,
- )
- except AssertionError as e:
- maybe_warn_nondeterministic()
- warnings.warn(
- f"Output nr {i + 1}. of the traced function does not match "
- f"the corresponding output of the {match_what}. Detailed error:\n{e}",
- category=TracerWarning,
- stacklevel=4,
- )
- all_ok = False
- return all_ok
- traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
- fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
- if compare_outputs(traced_outs, fn_outs, "Python function"):
- check_outs = run_mod_and_filter_tensor_outputs(
- check_mod_func, inputs, "repeated trace"
- )
- compare_outputs(traced_outs, check_outs, "repeated trace")
- diag_info = graph_diagnostic_info()
- if any(info is not None for info in diag_info):
- raise TracingCheckError(*diag_info)
- class TracerWarning(Warning):
- @staticmethod
- def ignore_lib_warnings():
- # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
- warnings.filterwarnings(
- "ignore", category=TracerWarning, module="torch.(?!jit)"
- )
- # We ignore the tracer warnings coming from inside the library, because all our shape
- # checks in nn will trigger them.
- TracerWarning.ignore_lib_warnings()
- torch._C._tracer_warn_use_python()
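- # Sketch: user code can widen this filter when tracer warnings are expected
- # (illustrative; `model` and `example` are assumptions):
- #
- #   import warnings
- #   with warnings.catch_warnings():
- #       warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
- #       traced = torch.jit.trace(model, example)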
- def make_tuple(example_inputs):
- if isinstance(example_inputs, (torch.Tensor, dict)):
- return (example_inputs,)
- # done primarily so that weird iterables fail here and not pybind11 code
- if not isinstance(example_inputs, tuple):
- return tuple(example_inputs)
- return example_inputs
- def make_module(mod, _module_class, _compilation_unit):
- if isinstance(mod, ScriptModule):
- return mod
- elif torch._jit_internal.module_has_exports(mod):
- infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
- return torch.jit._recursive.create_script_module(
- mod,
- infer_methods_stubs_fn,
- share_types=False,
- is_tracing=True
- )
- else:
- if _module_class is None:
- _module_class = TopLevelTracedModule
- return _module_class(mod, _compilation_unit=_compilation_unit)
- def wrap_check_inputs(check_inputs):
- if check_inputs is None:
- return None
- return [{"forward": c} for c in check_inputs]
- def trace(
- func,
- example_inputs=None,
- optimize=None,
- check_trace=True,
- check_inputs=None,
- check_tolerance=1e-5,
- strict=True,
- _force_outplace=False,
- _module_class=None,
- _compilation_unit=_python_cu,
- example_kwarg_inputs=None,
- _store_inputs=True
- ):
- """
- Trace a function and return an executable :class:`ScriptModule` or
- :class:`ScriptFunction` that will be optimized using just-in-time
- compilation. Tracing is ideal for
- code that operates only on ``Tensor``\\s and lists, dictionaries, and
- tuples of ``Tensor``\\s.
- Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an
- existing module or Python function into a TorchScript
- :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example
- inputs, and we run the function, recording the operations performed on all
- the tensors.
- * The resulting recording of a standalone function produces `ScriptFunction`.
- * The resulting recording of `nn.Module.forward` or `nn.Module` produces
- `ScriptModule`.
- The resulting module also contains the parameters of the original
- module.
- Warning:
- Tracing only correctly records functions and modules which are not data
- dependent (e.g., do not have conditionals on data in tensors) and do not have
- any untracked external dependencies (e.g., perform input/output or
- access global variables). Tracing only records operations done when the given
- function is run on the given tensors. Therefore, the returned
- `ScriptModule` will always run the same traced graph on any input. This
- has some important implications when your module is expected to run
- different sets of operations, depending on the input and/or the module
- state. For example,
- * Tracing will not record any control-flow like if-statements or loops.
- When this control-flow is constant across your module, this is fine
- and it often inlines the control-flow decisions. But sometimes the
- control-flow is actually part of the model itself. For instance, a
- recurrent network is a loop over the (possibly dynamic) length of an
- input sequence.
- * In the returned :class:`ScriptModule`, operations that have different
- behaviors in ``training`` and ``eval`` modes will always behave as if
- it is in the mode it was in during tracing, no matter which mode the
- `ScriptModule` is in.
- In cases like these, tracing would not be appropriate and
- :func:`scripting <torch.jit.script>` is a better choice. If you trace
- such models, you may silently get incorrect results on subsequent
- invocations of the model. The tracer will try to emit warnings when
- doing something that may cause an incorrect trace to be produced.
- Args:
- func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
- that will be run with `example_inputs`. `func` arguments and return
- values must be tensors or (possibly nested) tuples that contain
- tensors. When a module is passed to `torch.jit.trace`, only the
- ``forward`` method is run and traced (see :func:`torch.jit.trace_module
- <torch.jit.trace_module>` for details).
- Keyword arguments:
- example_inputs (tuple or torch.Tensor or None, optional): A tuple of example
- inputs that will be passed to the function while tracing.
- Default: ``None``. Either this argument or ``example_kwarg_inputs``
- should be specified. The resulting trace can be run with inputs of
- different types and shapes assuming the traced operations support those
- types and shapes. `example_inputs` may also be a single Tensor in which
- case it is automatically wrapped in a tuple. When the value is None,
- ``example_kwarg_inputs`` should be specified.
- check_trace (``bool``, optional): Check if the same inputs run through
- traced code produce the same outputs. Default: ``True``. You might want
- to disable this if, for example, your network contains non-
- deterministic ops or if you are sure that the network is correct despite
- a checker failure.
- check_inputs (list of tuples, optional): A list of tuples of input
- arguments that should be used to check the trace against what is
- expected. Each tuple is equivalent to a set of input arguments that
- would be specified in ``example_inputs``. For best results, pass in
- a set of checking inputs representative of the space of shapes and
- types of inputs you expect the network to see. If not specified,
- the original ``example_inputs`` are used for checking.
- check_tolerance (float, optional): Floating-point comparison tolerance
- to use in the checker procedure. This can be used to relax the
- checker strictness in the event that results diverge numerically
- for a known reason, such as operator fusion.
- strict (``bool``, optional): run the tracer in strict mode or not
- (default: ``True``). Only turn this off when you want the tracer to
- record your mutable container types (currently ``list``/``dict``)
- and you are sure that the containers you use have a ``constant``
- structure and are not used in control-flow (if, for) conditions.
- example_kwarg_inputs (dict, optional): A dict of keyword example inputs
- that will be passed to the function while tracing. Default: ``None``.
- Either this argument or ``example_inputs`` should be specified. The
- dict is unpacked using the traced function's argument names; if its
- keys do not match those names, a runtime exception is raised.
- Returns:
- If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns
- a :class:`ScriptModule` object with a single ``forward`` method
- containing the traced code. The returned `ScriptModule` will
- have the same set of sub-modules and parameters as the original
- ``nn.Module``. If ``func`` is a standalone function, ``trace``
- returns `ScriptFunction`.
- Example (tracing a function):
- .. testcode::
- import torch
- def foo(x, y):
- return 2 * x + y
- # Run `foo` with the provided inputs and record the tensor operations
- traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
- # `traced_foo` can now be run with the TorchScript interpreter or saved
- # and loaded in a Python-free environment
- Example (tracing an existing module)::
- import torch
- import torch.nn as nn
- class Net(nn.Module):
- def __init__(self):
- super().__init__()
- self.conv = nn.Conv2d(1, 1, 3)
- def forward(self, x):
- return self.conv(x)
- n = Net()
- example_weight = torch.rand(1, 1, 3, 3)
- example_forward_input = torch.rand(1, 1, 3, 3)
- # Trace a specific method and construct `ScriptModule` with
- # a single `forward` method
- module = torch.jit.trace(n.forward, example_forward_input)
- # Trace a module (implicitly traces `forward`) and construct a
- # `ScriptModule` with a single `forward` method
- module = torch.jit.trace(n, example_forward_input)
- """
- if not _enabled:
- return func
- if optimize is not None:
- warnings.warn(
- "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
- )
- if isinstance(func, torch.jit.ScriptModule):
- # A ScriptModule's forward method is already defined, so tracing it again
- # would result in an error.
- warnings.warn(
- "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is."
- )
- return func
- if isinstance(func, torch.nn.Module):
- if example_inputs is None:
- if isinstance(example_kwarg_inputs, dict):
- example_inputs = example_kwarg_inputs
- else:
- raise RuntimeError("example_kwarg_inputs should be a dict")
- return trace_module(
- func,
- {"forward": example_inputs},
- None,
- check_trace,
- wrap_check_inputs(check_inputs),
- check_tolerance,
- strict,
- _force_outplace,
- _module_class,
- example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
- _store_inputs=_store_inputs
- )
- if (
- hasattr(func, "__self__")
- and isinstance(func.__self__, torch.nn.Module)
- and func.__name__ == "forward"
- ):
- if example_inputs is None:
- if isinstance(example_kwarg_inputs, dict):
- example_inputs = example_kwarg_inputs
- else:
- raise RuntimeError("example_kwarg_inputs should be a dict")
- return trace_module(
- func.__self__,
- {"forward": example_inputs},
- None,
- check_trace,
- wrap_check_inputs(check_inputs),
- check_tolerance,
- strict,
- _force_outplace,
- _module_class,
- example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
- _store_inputs=_store_inputs
- )
- # Special case for common case of passing a single Tensor
- if isinstance(example_inputs, (torch.Tensor, dict)) and example_kwarg_inputs is None:
- example_inputs = (example_inputs,)
- # done primarily so that weird iterables fail here and not pybind11 code
- elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple):
- example_inputs = tuple(example_inputs)
- var_lookup_fn = _create_interpreter_name_lookup_fn(0)
- if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module):
- raise AttributeError(
- "trace doesn't support compiling individual module's functions.\n"
- "Please use trace_module"
- )
- name = _qualified_name(func)
- if isinstance(example_kwarg_inputs, dict):
- example_inputs = example_kwarg_inputs
- traced = torch._C._create_function_from_trace_with_dict(
- name,
- func,
- example_kwarg_inputs,
- var_lookup_fn,
- strict,
- _force_outplace,
- get_callable_argument_names(func)
- )
- else:
- traced = torch._C._create_function_from_trace(
- name,
- func,
- example_inputs,
- var_lookup_fn,
- strict,
- _force_outplace,
- get_callable_argument_names(func)
- )
- # Check the trace against new traces created from user-specified inputs
- if check_trace:
- if check_inputs is not None:
- _check_trace(
- check_inputs,
- func,
- traced,
- check_tolerance,
- strict,
- _force_outplace,
- False,
- _module_class,
- example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
- )
- else:
- _check_trace(
- [example_inputs],
- func,
- traced,
- check_tolerance,
- strict,
- _force_outplace,
- False,
- _module_class,
- example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
- )
- # Allow torch.compile() to inline
- traced._torchdynamo_inline = func # type: ignore[attr-defined]
- return traced
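- # Sketch: tracing with keyword example inputs, the path taken above when
- # `example_kwarg_inputs` is a dict (illustrative only):
- #
- #   def add(x, y):
- #       return x + y
- #   traced = torch.jit.trace(
- #       add, example_kwarg_inputs={"x": torch.rand(3), "y": torch.rand(3)}
- #   )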
- _trace_module_map: Optional[Dict[Any, Any]] = None
- def trace_module(
- mod,
- inputs,
- optimize=None,
- check_trace=True,
- check_inputs=None,
- check_tolerance=1e-5,
- strict=True,
- _force_outplace=False,
- _module_class=None,
- _compilation_unit=_python_cu,
- example_inputs_is_kwarg=False,
- _store_inputs=True,
- ):
- """
- Trace a module and return an executable :class:`ScriptModule` that will be optimized
- using just-in-time compilation. When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`, only
- the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of
- method names to example inputs to trace (see the ``inputs`` argument below).
- See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing.
- Args:
- mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are
- specified in ``inputs``. The given methods will be compiled
- as a part of a single `ScriptModule`.
- inputs (dict): A dict containing sample inputs indexed by method names in ``mod``.
- The inputs will be passed to methods whose names correspond to inputs'
- keys while tracing.
- ``{ 'forward' : example_forward_input, 'method2': example_method2_input}``
- Keyword arguments:
- check_trace (``bool``, optional): Check if the same inputs run through
- traced code produce the same outputs. Default: ``True``. You might want
- to disable this if, for example, your network contains non-
- deterministic ops or if you are sure that the network is correct despite
- a checker failure.
- check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used
- to check the trace against what is expected. Each dict
- is equivalent to a set of input arguments that would
- be specified in ``inputs``. For best results, pass in a
- set of checking inputs representative of the space of
- shapes and types of inputs you expect the network to see.
- If not specified, the original ``inputs`` are used for checking.
- check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
- This can be used to relax the checker strictness in the event that
- results diverge numerically for a known reason, such as operator fusion.
- example_inputs_is_kwarg (``bool``, optional): Indicates whether the example
- inputs are packed as keyword arguments. Default: ``False``.
- Returns:
- A :class:`ScriptModule` object with a single ``forward`` method containing the traced code.
- The returned :class:`ScriptModule` will have the same set of
- sub-modules and parameters as ``mod``.
- Example (tracing a module with multiple methods)::
- import torch
- import torch.nn as nn
- class Net(nn.Module):
- def __init__(self):
- super().__init__()
- self.conv = nn.Conv2d(1, 1, 3)
- def forward(self, x):
- return self.conv(x)
- def weighted_kernel_sum(self, weight):
- return weight * self.conv.weight
- n = Net()
- example_weight = torch.rand(1, 1, 3, 3)
- example_forward_input = torch.rand(1, 1, 3, 3)
- # Trace a specific method and construct `ScriptModule` with
- # a single `forward` method
- module = torch.jit.trace(n.forward, example_forward_input)
- # Trace a module (implicitly traces `forward`) and construct a
- # `ScriptModule` with a single `forward` method
- module = torch.jit.trace(n, example_forward_input)
- # Trace specific methods on a module (specified in `inputs`), constructs
- # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
- inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
- module = torch.jit.trace_module(n, inputs)
- """
- if not _enabled:
- return mod
- if optimize is not None:
- warnings.warn(
- "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
- )
- var_lookup_fn = _create_interpreter_name_lookup_fn(0)
- if not isinstance(mod, torch.nn.Module):
- raise AttributeError("expected torch.nn.Module as the first argument")
- if not isinstance(inputs, dict):
- raise AttributeError("expected a dictionary of (method_name, input) pairs")
- old_module_map = torch.jit._trace._trace_module_map
- try:
- trace_module_map: Dict[Any, Any] = {}
- def register_submods(mod, prefix):
- for name, child in mod.named_children():
- submod_qualname = prefix + "." + name
- trace_module_map[child] = submod_qualname
- register_submods(child, submod_qualname)
- trace_module_map["__module"] = mod
- torch.jit._trace._trace_module_map = trace_module_map
- register_submods(mod, "__module")
- module = make_module(mod, _module_class, _compilation_unit)
- for method_name, example_inputs in inputs.items():
- if method_name == "forward":
- # "forward" is a special case because we need to trace
- # `Module.__call__`, which sets up some extra tracing, but uses
- # argument names of the real `Module.forward` method.
- func = mod
- forward_method = getattr(mod, method_name)
- argument_names = get_callable_argument_names(forward_method)
- else:
- func = getattr(mod, method_name)
- argument_names = get_callable_argument_names(func)
- if isinstance(example_inputs, dict) and example_inputs_is_kwarg:
- # Raise an exception when the user-provided key names do not match the forward() method's argument names.
- for key in example_inputs:
- if key not in argument_names:
- valid_arguments = "[" + ",".join(argument_names) + "]"
- raise NameError(
- "'{}' is not an argument of the forward() method; "
- "valid argument names are {}".format(key, valid_arguments)
- )
- module._c._create_method_from_trace_with_dict(
- method_name,
- func,
- example_inputs,
- var_lookup_fn,
- strict,
- _force_outplace,
- argument_names,
- _store_inputs
- )
- else:
- example_inputs = make_tuple(example_inputs)
- module._c._create_method_from_trace(
- method_name,
- func,
- example_inputs,
- var_lookup_fn,
- strict,
- _force_outplace,
- argument_names,
- _store_inputs
- )
- check_trace_method = module._c._get_method(method_name)
- # Check the trace against new traces created from user-specified inputs
- if check_trace:
- if check_inputs is not None:
- _check_trace(
- check_inputs,
- func,
- check_trace_method,
- check_tolerance,
- strict,
- _force_outplace,
- True,
- _module_class,
- example_inputs_is_kwarg=example_inputs_is_kwarg,
- )
- else:
- _check_trace(
- [inputs],
- func,
- check_trace_method,
- check_tolerance,
- strict,
- _force_outplace,
- True,
- _module_class,
- example_inputs_is_kwarg=example_inputs_is_kwarg,
- )
- finally:
- torch.jit._trace._trace_module_map = old_module_map
- return module
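- # Sketch: re-checking a traced module against extra input shapes via
- # `check_inputs` (illustrative; `n` is the Net from the docstring above):
- #
- #   inputs = {"forward": torch.rand(1, 1, 8, 8)}
- #   check = [{"forward": torch.rand(1, 1, 16, 16)}]
- #   module = torch.jit.trace_module(n, inputs, check_inputs=check)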
- def is_tracing():
- """
- Return ``True`` if called during tracing (i.e., while code is being
- traced with ``torch.jit.trace``) and ``False`` otherwise.
- """
- if is_scripting():
- return False
- return torch._C._is_tracing()
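- # Sketch: guarding trace-unfriendly code on the tracer so it is not baked
- # into the trace (illustrative only):
- #
- #   def forward(self, x):
- #       if not torch.jit.is_tracing():
- #           assert x.dim() == 2   # eager-only shape check, skipped while tracing
- #       return x * 2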
- class TracedModule(ScriptModule):
- _disable_script_meta = True
- def __init__(self, orig, id_set=None, _compilation_unit=None):
- # XXX: orig can be a nn.Module or a function!
- super().__init__()
- assert isinstance(orig, torch.nn.Module)
- # Copy a subset of `orig` to a temporary nn.Module.
- # This is a way to customize what will actually get compiled by create_script_module
- id_set = set()
- # This allows us to preserve the original module's qualified name by defining a new
- # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name
- # we have a special case that will look up this attribute to override whatever qualname
- # we would get from the python type system
- class QualnameWrapper(torch.nn.Module):
- pass
- QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name( # type: ignore[attr-defined]
- type(orig)
- )
- tmp_module = QualnameWrapper()
- def check_unique(param):
- if param in id_set:
- raise ValueError(
- "TracedModules don't support parameter sharing between modules"
- )
- id_set.add(param)
- tmp_module.training = orig.training
- for name, param in orig._parameters.items():
- if param is not None:
- tmp_module._parameters[name] = param
- check_unique(param)
- for name, buf in orig._buffers.items():
- if buf is not None:
- tmp_module._buffers[name] = buf
- check_unique(buf)
- for name, val in orig.__dict__.items():
- if (
- torch._C._jit_is_script_object(val)
- and name not in orig._parameters
- and name not in orig._buffers
- ):
- setattr(tmp_module, name, val)
- if orig._backward_hooks:
- raise ValueError(
- "Modules that have backward hooks assigned can't be compiled: "
- + str(orig)
- )
- for name, submodule in orig._modules.items():
- if submodule is None:
- continue
- tmp_module._modules[name] = make_module(
- submodule, TracedModule, _compilation_unit=None
- )
- script_module = torch.jit._recursive.create_script_module(
- tmp_module, lambda module: (), share_types=False, is_tracing=True
- )
- self.__dict__["_name"] = type(orig).__name__
- self.__dict__["_actual_script_module"] = script_module
- for name in ("_parameters", "_buffers", "_modules", "training"):
- delattr(self, name)
- def forward(self, *args, **kwargs):
- raise RuntimeError("Trace submodules cannot be called.")
- def __getattr__(self, attr):
- if "_actual_script_module" not in self.__dict__:
- return super().__getattr__(attr)
- return getattr(self._actual_script_module, attr)
- def __setattr__(self, attr, value):
- if "_actual_script_module" not in self.__dict__:
- return super().__setattr__(attr, value)
- setattr(self._actual_script_module, attr, value)
- def _get_name(self):
- return self._name
- def extra_repr(self):
- return "original_name={}".format(self._name)
- class TopLevelTracedModule(TracedModule):
- forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
- def _reconstruct(self, cpp_module):
- """
- Re-construct an instance of TopLevelTracedModule using an instance of a C++ module.
- Args:
- cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
- """
- self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
- def _script_if_tracing(fn):
- @functools.wraps(fn)
- def wrapper(*args, **kwargs):
- if not is_tracing():
- # Not tracing, don't do anything
- return fn(*args, **kwargs)
- compiled_fn = script(wrapper.__original_fn) # type: ignore[attr-defined]
- return compiled_fn(*args, **kwargs)
- wrapper.__original_fn = fn # type: ignore[attr-defined]
- wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined]
- return wrapper
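- # Sketch: the decorated helper runs eagerly in normal execution and is
- # scripted on first use during a trace (illustrative only):
- #
- #   @_script_if_tracing
- #   def stack_all(xs: List[torch.Tensor]) -> torch.Tensor:
- #       return torch.stack(xs)   # list handling compiled by scripting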
- def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
- return_inputs=False, _return_inputs_states=False):
- """
- .. warning::
- This function is internal-only and should only be used by the ONNX
- exporter. If you are trying to get a graph through tracing, please go
- through the public API instead::
- trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
- trace_graph = trace.graph
- Trace a function or model, returning a tuple consisting of both the
- *trace* of an execution and the original return value. If ``return_inputs``
- is true, the trace inputs are also returned as part of the tuple.
- Tracing is guaranteed not to change the semantics of the function/module
- that is traced.
- Args:
- f (torch.nn.Module or function): the function or module
- to be traced.
- args (tuple or Tensor): the positional arguments to pass to the
- function/module to be traced. A non-tuple is assumed to
- be a single positional argument to be passed to the model.
- kwargs (dict): the keyword arguments to pass to the function/module
- to be traced.
- Example (trace a cell):
- .. testcode::
- trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
- """
- if kwargs is None:
- kwargs = {}
- if not isinstance(args, tuple):
- args = (args,)
- outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
- return outs