
  1. """Serialization
  2. This module contains functionality for serializing TorchScript modules, notably:
  3. * torch.jit.save
  4. * torch.jit.load
  5. This is not intended to be imported directly; please use the exposed
  6. functionalities in `torch.jit`.
  7. """
  8. import os
  9. import pathlib
  10. import torch
  11. from torch.jit._recursive import wrap_cpp_module
  12. from torch.serialization import validate_cuda_device


def save(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process. The
    saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load(filename)`` or into the Python API with
    :func:`torch.jit.load <torch.jit.load>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A file-like object (has to implement write and flush) or a string
           containing a file name.
        _extra_files: Map from filename to contents which will be stored as part of `f`.

    .. note::
        torch.jit.save attempts to preserve the behavior of some operators
        across versions. For example, dividing two integer tensors in
        PyTorch 1.5 performed floor division, and if the module containing
        that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6, its
        division behavior will be preserved. The same module saved in
        PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the
        behavior of division changed in 1.6, and 1.5 does not know how to
        replicate the 1.6 behavior.

    Example:

    .. testcode::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save(m, 'scriptmodule.pt')
        # This line is equivalent to the previous
        m.save("scriptmodule.pt")

        # Save to io.BytesIO buffer
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)

        # Save with extra files
        extra_files = {'foo.txt': b'bar'}
        torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
    """
    if _extra_files is None:
        _extra_files = {}
    if isinstance(f, (str, pathlib.Path)):
        m.save(f, _extra_files=_extra_files)
    else:
        ret = m.save_to_buffer(_extra_files=_extra_files)
        f.write(ret)


def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
    r"""
    Load a :class:`ScriptModule` or :class:`ScriptFunction` previously
    saved with :func:`torch.jit.save <torch.jit.save>`.

    All previously saved modules, no matter their device, are first loaded onto CPU,
    and then are moved to the devices they were saved from. If this fails (e.g.
    because the runtime system doesn't have certain devices), an exception is
    raised.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location (string or torch.device): A simplified version of
            ``map_location`` in :func:`torch.load` used to dynamically remap
            storages to an alternative set of devices.
        _extra_files (dictionary of filename to content): The extra
            filenames given in the map would be loaded and their content
            would be stored in the provided map.
        _restore_shapes (bool): Whether or not to retrace the module on load using stored inputs

    Returns:
        A :class:`ScriptModule` object.

    Example:

    .. testcode::

        import torch
        import io

        torch.jit.load('scriptmodule.pt')

        # Load ScriptModule from io.BytesIO object
        with open('scriptmodule.pt', 'rb') as f:
            buffer = io.BytesIO(f.read())

        # Load all tensors to the original device
        torch.jit.load(buffer)

        # Load all tensors onto CPU, using a device
        buffer.seek(0)
        torch.jit.load(buffer, map_location=torch.device('cpu'))

        # Load all tensors onto CPU, using a string
        buffer.seek(0)
        torch.jit.load(buffer, map_location='cpu')

        # Load with extra files.
        extra_files = {'foo.txt': ''}  # values will be replaced with data
        torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
        print(extra_files['foo.txt'])

    .. testoutput::
        :hide:

        ...

    .. testcleanup::

        import os
        os.remove("scriptmodule.pt")
    """
    if isinstance(f, str):
        if not os.path.exists(f):  # type: ignore[type-var]
            raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore[str-bytes-safe]
        if os.path.isdir(f):
            raise ValueError("The provided filename {} is a directory".format(f))  # type: ignore[str-bytes-safe]

    map_location = validate_map_location(map_location)
    if _extra_files is None:
        _extra_files = {}

    cu = torch._C.CompilationUnit()
    if isinstance(f, (str, pathlib.Path)):
        cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
    else:
        cpp_module = torch._C.import_ir_module_from_buffer(
            cu, f.read(), map_location, _extra_files, _restore_shapes
        )  # type: ignore[call-arg]

    # TODO: Pretty sure this approach loses ConstSequential status and such
    return wrap_cpp_module(cpp_module)


def validate_map_location(map_location=None):
    if isinstance(map_location, str):
        map_location = torch.device(map_location)
    elif not (map_location is None or isinstance(map_location, torch.device)):
        raise ValueError(
            "map_location should be either None, string or torch.device, "
            "but got type: " + str(type(map_location))
        )

    if str(map_location).startswith("cuda"):
        validate_cuda_device(map_location)

    return map_location
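

# Illustrative sketch (comments only, not part of the module) of how
# validate_map_location normalizes its input: strings are converted to
# torch.device, None passes through, other types raise, and any "cuda"
# location is additionally checked against available devices via
# validate_cuda_device.
#
#     validate_map_location(None)                  # -> None
#     validate_map_location("cpu")                 # -> torch.device("cpu")
#     validate_map_location(torch.device("cpu"))   # -> torch.device("cpu")
#     validate_map_location(0)                     # raises ValueError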


def get_ff_module():
    try:
        import torch._C_flatbuffer as ff

        return ff
    except ImportError:
        print("Please include //caffe2:_C_flatbuffer as dependency.")
        raise


def jit_module_from_flatbuffer(f):
    """Load a :class:`ScriptModule` from a file path or a file-like object
    containing a module saved in flatbuffer format (see
    :func:`save_jit_module_to_flatbuffer`).
    """
    ff = get_ff_module()

    if isinstance(f, str):
        if not os.path.exists(f):  # type: ignore[type-var]
            raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore[str-bytes-safe]
        if os.path.isdir(f):
            raise ValueError("The provided filename {} is a directory".format(f))  # type: ignore[str-bytes-safe]

    if isinstance(f, (str, pathlib.Path)):
        f = str(f)
        return wrap_cpp_module(ff._load_jit_module_from_file(f))
    else:
        return wrap_cpp_module(ff._load_jit_module_from_bytes(f.read()))
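

# Usage sketch (illustrative only): round-tripping a scripted module through
# the flatbuffer format. Assumes a build that ships torch._C_flatbuffer;
# 'scriptmodule.ff' is a hypothetical path.
#
#     m = torch.jit.script(MyModule())
#     save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
#     loaded = jit_module_from_flatbuffer('scriptmodule.ff')
#     loaded(torch.ones(2))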


def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process. The
    saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with
    :func:`torch.jit.jit_module_from_flatbuffer<torch.jit.jit_module_from_flatbuffer>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A string containing a file path, or a file-like object with a
            ``write`` method.

    Example:

    .. testcode::

        import torch

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
    """
    extra_files = _extra_files
    if extra_files is None:
        extra_files = {}

    ff = get_ff_module()
    if isinstance(f, (str, pathlib.Path)):
        f = str(f)
        ff._save_jit_module(m._c, f, extra_files)
    else:
        s = ff._save_jit_module_to_bytes(m._c, extra_files)
        f.write(s)
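

# Usage sketch (illustrative only): the else branch above also supports
# saving to an in-memory buffer instead of a file path.
#
#     import io
#     buf = io.BytesIO()
#     save_jit_module_to_flatbuffer(m, buf)
#     raw_bytes = buf.getvalue()  # the serialized flatbuffer payload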


def get_flatbuffer_module_info(path_or_file):
    r"""Get some information regarding a model file in flatbuffer format.

    Args:
        path_or_file: Either str, Path or file-like object (BytesIO OK).
            If it's str or Path, we will read the file referenced by that
            path as bytes.

    Returns:
        A dict with metadata on what that file contains; currently it looks
        like this::

            {
                'bytecode_version': 4,  # int
                'operator_version': 4,  # int
                'function_names': {
                    '__torch__.___torch_mangle_0.Foo.forward'
                },  # set
                'type_names': set(),  # set
                'opname_to_num_args': {'aten::linear': 3}  # Dict[str, int]
            }
    """
    ff = get_ff_module()
    if isinstance(path_or_file, (str, pathlib.Path)):
        with open(path_or_file, "rb") as f:
            all_bytes = f.read()
    else:
        all_bytes = path_or_file.read()
    return ff._get_module_info_from_flatbuffer(all_bytes)
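

# Usage sketch (illustrative only): inspecting a flatbuffer file's metadata
# without constructing the module. 'scriptmodule.ff' is a hypothetical path;
# the keys used below are those documented in the docstring above.
#
#     info = get_flatbuffer_module_info('scriptmodule.ff')
#     print(info['bytecode_version'], sorted(info['function_names']))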