- """Serialization
- This module contains functionality for serializing TorchScript modules, notably:
- * torch.jit.save
- * torch.jit.load
- This is not intended to be imported directly; please use the exposed
- functionalities in `torch.jit`.
- """
- import os
- import pathlib
- import torch
- from torch.jit._recursive import wrap_cpp_module
- from torch.serialization import validate_cuda_device


def save(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process. The
    saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load(filename)`` or into the Python API with
    :func:`torch.jit.load <torch.jit.load>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A file-like object (has to implement write and flush) or a string
           containing a file name.
        _extra_files: Map from filename to contents which will be stored as part of `f`.

    .. note::
        torch.jit.save attempts to preserve the behavior of some operators
        across versions. For example, dividing two integer tensors in
        PyTorch 1.5 performed floor division, and if the module
        containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6
        its division behavior will be preserved. The same module saved in
        PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the
        behavior of division changed in 1.6, and 1.5 does not know how to
        replicate the 1.6 behavior.

    Example:

    .. testcode::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save(m, 'scriptmodule.pt')
        # This line is equivalent to the previous
        m.save("scriptmodule.pt")

        # Save to io.BytesIO buffer
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)

        # Save with extra files
        extra_files = {'foo.txt': b'bar'}
        torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
    """
    if _extra_files is None:
        _extra_files = {}
    if isinstance(f, (str, pathlib.Path)):
        # Save directly to the filesystem path.
        m.save(f, _extra_files=_extra_files)
    else:
        # Serialize to an in-memory buffer and write it to the file-like object.
        ret = m.save_to_buffer(_extra_files=_extra_files)
        f.write(ret)


def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
    r"""
    Load a :class:`ScriptModule` or :class:`ScriptFunction` previously
    saved with :func:`torch.jit.save <torch.jit.save>`.

    All previously saved modules, no matter their device, are first loaded onto CPU,
    and then are moved to the devices they were saved from. If this fails (e.g.
    because the runtime system doesn't have certain devices), an exception is
    raised.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location (string or torch.device): A simplified version of
            ``map_location`` in :func:`torch.load` used to dynamically remap
            storages to an alternative set of devices.
        _extra_files (dictionary of filename to content): The extra
            filenames given in the map would be loaded and their content
            would be stored in the provided map.
        _restore_shapes (bool): Whether or not to retrace the module on load using stored inputs

    Returns:
        A :class:`ScriptModule` object.

    Example:

    .. testcode::

        import torch
        import io

        torch.jit.load('scriptmodule.pt')

        # Load ScriptModule from io.BytesIO object
        with open('scriptmodule.pt', 'rb') as f:
            buffer = io.BytesIO(f.read())

        # Load all tensors to the original device
        torch.jit.load(buffer)

        # Load all tensors onto CPU, using a device
        buffer.seek(0)
        torch.jit.load(buffer, map_location=torch.device('cpu'))

        # Load all tensors onto CPU, using a string
        buffer.seek(0)
        torch.jit.load(buffer, map_location='cpu')
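
        # Remapping onto a CUDA device works the same way (illustrative only;
        # left commented out so this doctest does not require a GPU):
        # buffer.seek(0)
        # torch.jit.load(buffer, map_location='cuda:0')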

        # Load with extra files.
        extra_files = {'foo.txt': ''}  # values will be replaced with data
        torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
        print(extra_files['foo.txt'])

    .. testoutput::
        :hide:

        ...

    .. testcleanup::

        import os
        os.remove("scriptmodule.pt")
    """
    if isinstance(f, (str, pathlib.Path)):
        if not os.path.exists(f):  # type: ignore[type-var]
            raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore[str-bytes-safe]
        if os.path.isdir(f):
            raise ValueError("The provided filename {} is a directory".format(f))  # type: ignore[str-bytes-safe]

    map_location = validate_map_location(map_location)

    if _extra_files is None:
        _extra_files = {}

    cu = torch._C.CompilationUnit()
    if isinstance(f, (str, pathlib.Path)):
        cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
    else:
        cpp_module = torch._C.import_ir_module_from_buffer(
            cu, f.read(), map_location, _extra_files, _restore_shapes
        )  # type: ignore[call-arg]

    # TODO: Pretty sure this approach loses ConstSequential status and such
    return wrap_cpp_module(cpp_module)


def validate_map_location(map_location=None):
    r"""
    Normalize and validate a ``map_location`` argument: a string is converted
    to a :class:`torch.device`, ``None`` is passed through unchanged, and any
    other type raises a ``ValueError``. CUDA locations are additionally
    checked against the available devices.
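
    Example (illustrative; this is an internal helper, not public API)::

        validate_map_location('cpu')                 # -> torch.device('cpu')
        validate_map_location(torch.device('cpu'))   # -> torch.device('cpu')
        validate_map_location(None)                  # -> None
    """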
    if isinstance(map_location, str):
        map_location = torch.device(map_location)
    elif not (map_location is None or isinstance(map_location, torch.device)):
        raise ValueError(
            "map_location should be either None, string or torch.device, "
            "but got type: " + str(type(map_location))
        )

    if str(map_location).startswith("cuda"):
        validate_cuda_device(map_location)

    return map_location


def get_ff_module():
    try:
        import torch._C_flatbuffer as ff

        return ff
    except ImportError:
        print("Please include //caffe2:_C_flatbuffer as dependency.")
        raise


def jit_module_from_flatbuffer(f):
    r"""
    Load a :class:`ScriptModule` saved in the flatbuffer format (see
    :func:`save_jit_module_to_flatbuffer`) from a file path or a file-like
    object.
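
    Example (a minimal sketch; assumes ``scriptmodule.ff`` was produced by
    :func:`save_jit_module_to_flatbuffer` below)::

        m = torch.jit.jit_module_from_flatbuffer('scriptmodule.ff')
        print(m(torch.ones(1)))
    """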
    ff = get_ff_module()
    if isinstance(f, (str, pathlib.Path)):
        if not os.path.exists(f):  # type: ignore[type-var]
            raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore[str-bytes-safe]
        if os.path.isdir(f):
            raise ValueError("The provided filename {} is a directory".format(f))  # type: ignore[str-bytes-safe]
        f = str(f)
        return wrap_cpp_module(ff._load_jit_module_from_file(f))
    else:
        return wrap_cpp_module(ff._load_jit_module_from_bytes(f.read()))


def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process. The
    saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with
    :func:`torch.jit.jit_module_from_flatbuffer<torch.jit.jit_module_from_flatbuffer>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A file-like object (has to implement write and flush) or a string
           containing a file name.
        _extra_files: Map from filename to contents which will be stored as part of `f`.

    Example:

    .. testcode::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
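
        # Saving to a file-like object is also supported (a sketch; the
        # buffer then holds the flatbuffer-serialized module)
        buffer = io.BytesIO()
        torch.jit.save_jit_module_to_flatbuffer(m, buffer)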
- """
    extra_files = _extra_files
    if extra_files is None:
        extra_files = {}

    ff = get_ff_module()
    if isinstance(f, (str, pathlib.Path)):
        f = str(f)
        ff._save_jit_module(m._c, f, extra_files)
    else:
        s = ff._save_jit_module_to_bytes(m._c, extra_files)
        f.write(s)


def get_flatbuffer_module_info(path_or_file):
    r"""Get some information regarding a model file in flatbuffer format.

    Args:
        path_or_file: Either str, Path or file-like object (BytesIO OK).
            If it's str or Path, we will read the file referenced by that
            path as bytes.

    Returns:
        A dict with metadata on what that file contains, currently looking like
        this::

            {
                'bytecode_version': 4,  # int
                'operator_version': 4,  # int
                'function_names': {
                    '__torch__.___torch_mangle_0.Foo.forward'},  # set
                'type_names': set(),  # set
                'opname_to_num_args': {'aten::linear': 3}  # Dict[str, int]
            }
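
    Example (illustrative sketch; assumes ``scriptmodule.ff`` was produced by
    :func:`save_jit_module_to_flatbuffer` above)::

        info = get_flatbuffer_module_info('scriptmodule.ff')
        print(info['bytecode_version'], sorted(info['function_names']))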
- """
    ff = get_ff_module()
    if isinstance(path_or_file, (str, pathlib.Path)):
        with open(path_or_file, "rb") as f:
            all_bytes = f.read()
    else:
        all_bytes = path_or_file.read()
    return ff._get_module_info_from_flatbuffer(all_bytes)
|