- """Freezing
- This is not intended to be imported directly; please use the exposed
- functionalities in `torch.jit`.
- """
- from typing import Optional, List
- import torch
- from torch.jit._script import RecursiveScriptModule, ScriptModule
- def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True):
- r"""
- Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
- module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
- By default, `forward` will be preserved, as well as attributes & methods specified in
- `preserved_attrs`. Additionally, any attribute that is modified within a preserved
- method will be preserved.
- Freezing currently only accepts ScriptModules that are in eval mode.
- Freezing applies generic optimization that will speed up your model regardless of machine.
- To further optimize using server-specific settings, run `optimize_for_inference` after
- freezing.
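
    For example (a sketch; ``model`` stands in for your own ``nn.Module``):

    .. code-block:: python

        frozen = torch.jit.freeze(torch.jit.script(model.eval()))
        frozen = torch.jit.optimize_for_inference(frozen)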
    Args:
        mod (:class:`ScriptModule`): a module to be frozen
        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
            Attributes modified in preserved methods will also be preserved.
        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not strictly
            preserve numerics. Full details of the optimizations can be found at `torch.jit.run_frozen_optimizations`.

    Returns:
        Frozen :class:`ScriptModule`.
    Example (Freezing a simple module with a Parameter):

    .. testcode::

        import torch

        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # See the compiled graph as Python code
        print(frozen_module.code)
    Example (Freezing a module with preserved attributes):

    .. testcode::

        import torch

        class MyModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)
    Note:
        Freezing submodule attributes is also supported:
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])
    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on `frozen_module.forward.graph` to see if freezing has detected that
        the attribute is being modified.
    Note:
        Because freezing makes weights constants and removes module hierarchy, `to` and other
        nn.Module methods to manipulate device or dtype no longer work. As a workaround,
        you can remap devices by specifying `map_location` in `torch.jit.load`; however,
        device-specific logic may have been baked into the model.
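
        For example (a sketch; the file name is a placeholder), a frozen module saved
        on one device can be remapped at load time, subject to the caveat above:

        .. code-block:: python

            torch.jit.save(frozen_module, "frozen.pt")
            loaded = torch.jit.load("frozen.pt", map_location="cpu")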
- """
- if not isinstance(mod, ScriptModule):
- raise RuntimeError(
- "Freezing expects a ScriptModule as input. "
- "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
- )
- if mod.training:
- raise RuntimeError(
- "Freezing is currently only implemented for modules in eval mode. "
- "Please call .eval() on your module before freezing."
- )
- preserved_attrs = preserved_attrs if preserved_attrs is not None else []
- out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
- RecursiveScriptModule._finalize_scriptmodule(out)
- preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)]
- run_frozen_optimizations(out, optimize_numerics, preserved_methods)
- return out
- def run_frozen_optimizations(
- mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
- ):
- r"""
- Runs a series of optimizations looking for patterns that occur in frozen graphs.
- The current set of optimizations includes:
- - Dropout Removal
- - Pretranspose Linear Layers
- - Concat Linear Layers with same input Tensor
- - Conv -> Batchnorm folding
- - Conv -> Add/Sub folding
- - Conv -> Mul/Div folding
- Args:
- mod (:class:`ScriptModule`): a frozen module to be optimized
- optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
- preserve numerics. These optimizations preserve default rtol and atol of `torch.testing.assert_close`
- when applied on a single transformation, however in a module where many transformations are applied
- the rtol or atol may no longer fall within the default `assert_close` tolerance. Conv -> Batchnorm folding,
- Conv-Add/Sub, and Conv -> Mul/Div folding all may alter numerics.
- Returns:
- None
- Note:
- In rare occassions, this can result in slower execution.
    Example (Freezing a module with Conv->Batchnorm):

    .. code-block:: python

        import torch

        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)

        # pass optimize_numerics=False so the Conv -> Batchnorm fold is not applied here;
        # by default, freezing runs run_frozen_optimizations
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)

        # inspect the frozen module: the batch norm is still present
        assert "batch_norm" in str(frozen_mod.graph)

        torch.jit.run_frozen_optimizations(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)
- """
- if mod._c._has_method("forward"):
- torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)
- if preserved_methods is None:
- preserved_methods = []
- for method in preserved_methods:
- torch._C._jit_pass_optimize_frozen_graph(
- mod.__getattr__(method).graph, optimize_numerics
- )
- def optimize_for_inference(mod: ScriptModule, other_methods: Optional[List[str]] = None) -> ScriptModule:
- """
- Performs a set of optimization passes to optimize a model for the
- purposes of inference. If the model is not already frozen, optimize_for_inference
- will invoke `torch.jit.freeze` automatically.
- In addition to generic optimizations that should speed up your model regardless
- of environment, prepare for inference will also bake in build specific settings
- such as the presence of CUDNN or MKLDNN, and may in the future make transformations
- which speed things up on one machine but slow things down on another. Accordingly,
- serialization is not implemented following invoking `optimize_for_inference` and
- is not guaranteed.
- This is still in prototype, and may have the potential to slow down your model.
- Primary use cases that have been targeted so far have been vision models on cpu
- and gpu to a lesser extent.
    Example (optimizing a module with Conv->Batchnorm)::

        import torch

        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)

        frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
        assert "batch_norm" not in str(frozen_mod.graph)
        # if built with MKLDNN, convolution will be run with MKLDNN weights
        assert "MKLDNN" in str(frozen_mod.graph)
- """
- if not isinstance(mod, ScriptModule):
- raise RuntimeError(
- "optimize_for_inference expects a ScriptModule as input. "
- "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
- if other_methods is None:
- other_methods = []
- if hasattr(mod, "training"):
- mod = freeze(mod.eval(), preserved_attrs=other_methods)
- torch._C._jit_pass_optimize_for_inference(mod._c, other_methods)
- return mod