__init__.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. #!/usr/bin/env python3
  2. """
  3. model_dump: a one-stop shop for TorchScript model inspection.
  4. The goal of this tool is to provide a simple way to extract lots of
  5. useful information from a TorchScript model and make it easy for humans
  6. to consume. It (mostly) replaces zipinfo, common uses of show_pickle,
  7. and various ad-hoc analysis notebooks.
  8. The tool extracts information from the model and serializes it as JSON.
  9. That JSON can then be rendered by an HTML+JS page, either by
  10. loading the JSON over HTTP or producing a fully self-contained page
  11. with all of the code and data burned-in.
  12. """
  13. # Maintainer notes follow.
  14. """
  15. The implementation strategy has tension between 3 goals:
  16. - Small file size.
  17. - Fully self-contained.
  18. - Easy, modern JS environment.
  19. Using Preact and HTM achieves 1 and 2 with a decent result for 3.
  20. However, the models I tested with result in ~1MB JSON output,
  21. so even using something heavier like full React might be tolerable
  22. if the build process can be worked out.
  23. One principle I have followed that I think is very beneficial
  24. is to keep the JSON data as close as possible to the model
  25. and do most of the rendering logic on the client.
  26. This makes for easier development (just refresh, usually),
  27. allows for more laziness and dynamism, and lets us add more
  28. views of the same data without bloating the HTML file.
  29. Currently, this code doesn't actually load the model or even
  30. depend on any part of PyTorch. I don't know if that's an important
  31. feature to maintain, but it's probably worth preserving the ability
  32. to run at least basic analysis on models that cannot be loaded.
  33. I think the easiest way to develop this code is to cd into model_dump and
  34. run "python -m http.server", then load http://localhost:8000/skeleton.html
  35. in the browser. In another terminal, run
  36. "python -m torch.utils.model_dump --style=json FILE > \
  37. torch/utils/model_dump/model_info.json"
  38. every time you update the Python code or model.
  39. When you update JS, just refresh.
  40. Possible improvements:
  41. - Fix various TODO comments in this file and the JS.
  42. - Make the HTML much less janky, especially the auxiliary data panel.
  43. - Make the auxiliary data panel start small, expand when
  44. data is available, and have a button to clear/contract.
  45. - Clean up the JS. There's a lot of copypasta because
  46. I don't really know how to use Preact.
  47. - Make the HTML render and work nicely inside a Jupyter notebook.
  48. - Add the ability for JS to choose the URL to load the JSON based
  49. on the page URL (query or hash). That way we could publish the
  50. inlined skeleton once and have it load various JSON blobs.
  51. - Add a button to expand all expandable sections so ctrl-F works well.
  52. - Add hyperlinking from data to code, and code to code.
  53. - Add hyperlinking from debug info to Diffusion.
  54. - Make small tensor contents available.
  55. - Do something nice for quantized models
  56. (they probably don't work at all right now).
  57. """
  58. import sys
  59. import os
  60. import io
  61. import pathlib
  62. import re
  63. import argparse
  64. import zipfile
  65. import json
  66. import pickle
  67. import pprint
  68. import urllib.parse
  69. from typing import (
  70. Dict,
  71. )
  72. import torch.utils.show_pickle
# Upper bound used by get_model_info when deciding whether to inline the
# contents of "extra" files into the report: compared against the raw file
# size for extra JSON files and against the rendered length for pickles.
DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024

# Public API of this module.
__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton',
           'burn_in_info', 'get_info_and_burn_skeleton']
  76. def get_storage_info(storage):
  77. assert isinstance(storage, torch.utils.show_pickle.FakeObject)
  78. assert storage.module == "pers"
  79. assert storage.name == "obj"
  80. assert storage.state is None
  81. assert isinstance(storage.args, tuple)
  82. assert len(storage.args) == 1
  83. sa = storage.args[0]
  84. assert isinstance(sa, tuple)
  85. assert len(sa) == 5
  86. assert sa[0] == "storage"
  87. assert isinstance(sa[1], torch.utils.show_pickle.FakeClass)
  88. assert sa[1].module == "torch"
  89. assert sa[1].name.endswith("Storage")
  90. storage_info = [sa[1].name.replace("Storage", "")] + list(sa[2:])
  91. return storage_info
def hierarchical_pickle(data):
    """Recursively convert unpickled model data into JSON-friendly structures.

    Primitives pass through unchanged; tuples and dicts are wrapped in
    tagged objects (JSON has no tuples and only string keys); FakeObject
    instances produced by show_pickle are translated case-by-case into
    tagged objects the JS frontend understands.

    Raises Exception for any FakeObject type or Python type not handled
    below.
    """
    if isinstance(data, (bool, int, float, str, type(None))):
        return data
    if isinstance(data, list):
        return [hierarchical_pickle(d) for d in data]
    if isinstance(data, tuple):
        # Tag tuples so the frontend can distinguish them from lists.
        return {
            "__tuple_values__": hierarchical_pickle(list(data)),
        }
    if isinstance(data, dict):
        # JSON keys must be strings, so keys and values are emitted as
        # parallel lists instead of a plain object.
        return {
            "__is_dict__": True,
            "keys": hierarchical_pickle(list(data.keys())),
            "values": hierarchical_pickle(list(data.values())),
        }
    if isinstance(data, torch.utils.show_pickle.FakeObject):
        typename = f"{data.module}.{data.name}"
        if (
            typename.startswith("__torch__.") or
            typename.startswith("torch.jit.LoweredWrapper.") or
            typename.startswith("torch.jit.LoweredModule.")
        ):
            # A TorchScript module: only its state dict matters.
            assert data.args == ()
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle(data.state),
            }
        if typename == "torch._utils._rebuild_tensor_v2":
            assert data.state is None
            # Older serializations have 6 args; newer ones add metadata.
            if len(data.args) == 6:
                storage, offset, size, stride, requires_grad, hooks = data.args
            else:
                storage, offset, size, stride, requires_grad, hooks, metadata = data.args
            storage_info = get_storage_info(storage)
            # hooks (and metadata) are intentionally dropped.
            return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]}
        if typename == "torch._utils._rebuild_qtensor":
            assert data.state is None
            storage, offset, size, stride, quantizer, requires_grad, hooks = data.args
            storage_info = get_storage_info(storage)
            assert isinstance(quantizer, tuple)
            assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass)
            assert quantizer[0].module == "torch"
            if quantizer[0].name == "per_tensor_affine":
                # per_tensor_affine carries (scale: float, zero_point: int).
                assert len(quantizer) == 3
                assert isinstance(quantizer[1], float)
                assert isinstance(quantizer[2], int)
                quantizer_extra = list(quantizer[1:3])
            else:
                # Other quantization schemes: keep only the scheme name.
                quantizer_extra = []
            quantizer_json = [quantizer[0].name] + quantizer_extra
            return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]}
        if typename == "torch.jit._pickle.restore_type_tag":
            # Drop the type tag; keep the wrapped object.
            assert data.state is None
            obj, typ = data.args
            assert isinstance(typ, str)
            return hierarchical_pickle(obj)
        if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename):
            # build_intlist / build_floatlist / etc.: unwrap to a plain list.
            assert data.state is None
            ls, = data.args
            assert isinstance(ls, list)
            return hierarchical_pickle(ls)
        if typename == "torch.device":
            assert data.state is None
            name, = data.args
            assert isinstance(name, str)
            # Just forget that it was a device and return the name.
            return name
        if typename == "builtin.UnicodeDecodeError":
            assert data.state is None
            msg, = data.args
            assert isinstance(msg, str)
            # Hack: Pretend this is a module so we don't need custom serialization.
            # Hack: Wrap the message in a tuple so it looks like a nice state object.
            # TODO: Undo at least that second hack.  We should support string states.
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle((msg,)),
            }
        raise Exception(f"Can't prepare fake object of type for JS: {typename}")
    raise Exception(f"Can't prepare data of type for JS: {type(data)}")
def get_model_info(
        path_or_file,
        title=None,
        extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT):
    """Get JSON-friendly information about a model.

    The result is suitable for being saved as model_info.json,
    or passed to burn_in_info.

    Args:
        path_or_file: str or os.PathLike path to a TorchScript zip
            archive, or a seekable binary file object containing one.
        title: display title; defaults to the path ("buffer" for
            file objects).
        extra_file_size_limit: skip extra files whose size (JSON) or
            rendered length (pickles) exceeds this limit.
    """
    # Determine the default title and total file size for each of the
    # three accepted input forms.
    if isinstance(path_or_file, os.PathLike):
        default_title = os.fspath(path_or_file)
        file_size = path_or_file.stat().st_size  # type: ignore[attr-defined]
    elif isinstance(path_or_file, str):
        default_title = path_or_file
        file_size = pathlib.Path(path_or_file).stat().st_size
    else:
        # Assume a seekable file object; measure by seeking to the end.
        default_title = "buffer"
        path_or_file.seek(0, io.SEEK_END)
        file_size = path_or_file.tell()
        path_or_file.seek(0)

    title = title or default_title

    with zipfile.ZipFile(path_or_file) as zf:
        # All archive members are expected to live under a single
        # top-level directory; record its name and the member listing.
        path_prefix = None
        zip_files = []
        for zi in zf.infolist():
            prefix = re.sub("/.*", "", zi.filename)
            if path_prefix is None:
                path_prefix = prefix
            elif prefix != path_prefix:
                raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")
            zip_files.append(dict(
                filename=zi.filename,
                compression=zi.compress_type,
                compressed_size=zi.compress_size,
                file_size=zi.file_size,
            ))

        assert path_prefix is not None
        version = zf.read(path_prefix + "/version").decode("utf-8").strip()

        def get_pickle(name):
            # Load <prefix>/<name>.pkl through the safe fake-object
            # unpickler and convert it to JSON-friendly form.
            assert path_prefix is not None
            with zf.open(path_prefix + f"/{name}.pkl") as handle:
                raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
                return hierarchical_pickle(raw)

        model_data = get_pickle("data")
        constants = get_pickle("constants")

        # Intern strings that are likely to be re-used.
        # Pickle automatically detects shared structure,
        # so re-used strings are stored efficiently.
        # However, JSON has no way of representing this,
        # so we have to do it manually.
        interned_strings : Dict[str, int] = {}

        def ist(s):
            # Return a stable integer id for s, assigning the next id on
            # first sight (ids are therefore dense insertion-order indexes).
            if s not in interned_strings:
                interned_strings[s] = len(interned_strings)
            return interned_strings[s]

        code_files = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".py"):
                continue
            with zf.open(zi) as handle:
                raw_code = handle.read()
            with zf.open(zi.filename + ".debug_pkl") as handle:
                raw_debug = handle.read()
            # Parse debug info and add begin/end markers if not present
            # to ensure that we cover the entire source code.
            debug_info_t = pickle.loads(raw_debug)
            text_table = None

            # Newer archives use a string-table format; expand it back
            # into the older inline-text format.
            if (len(debug_info_t) == 3 and
                    isinstance(debug_info_t[0], str) and
                    debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
                _, text_table, content = debug_info_t

                def parse_new_format(line):
                    # (0, (('', '', 0), 0, 0))
                    num, ((text_indexes, fname_idx, offset), start, end), tag = line
                    text = ''.join(text_table[x] for x in text_indexes)  # type: ignore[index]
                    fname = text_table[fname_idx]  # type: ignore[index]
                    return num, ((text, fname, offset), start, end), tag

                debug_info_t = map(parse_new_format, content)

            debug_info = list(debug_info_t)
            if not debug_info:
                debug_info.append((0, (('', '', 0), 0, 0)))
            if debug_info[-1][0] != len(raw_code):
                debug_info.append((len(raw_code), (('', '', 0), 0, 0)))

            # Slice the source between consecutive debug entries and pair
            # each slice with its (interned) source-range metadata.
            code_parts = []
            for di, di_next in zip(debug_info, debug_info[1:]):
                start, source_range, *_ = di
                end = di_next[0]
                assert end > start
                source, s_start, s_end = source_range
                s_text, s_file, s_line = source
                # TODO: Handle this case better.  TorchScript ranges are in bytes,
                # but JS doesn't really handle byte strings.
                # If bytes and chars are not equivalent for this string,
                # zero out the ranges so we don't highlight the wrong thing.
                if len(s_text) != len(s_text.encode("utf-8")):
                    s_start = 0
                    s_end = 0
                text = raw_code[start:end]
                code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end])
            code_files[zi.filename] = code_parts

        # Collect small JSON files from the "extra" directory.
        extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json")
        extra_files_jsons = {}
        for zi in zf.infolist():
            if not extra_files_json_pattern.fullmatch(zi.filename):
                continue
            if zi.file_size > extra_file_size_limit:
                continue
            with zf.open(zi) as handle:
                try:
                    json_content = json.load(handle)
                    extra_files_jsons[zi.filename] = json_content
                except json.JSONDecodeError:
                    extra_files_jsons[zi.filename] = "INVALID JSON"

        # Render remaining pickles with pprint; these are always included
        # regardless of size.
        always_render_pickles = {
            "bytecode.pkl",
        }
        extra_pickles = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".pkl"):
                continue
            with zf.open(zi) as handle:
                # TODO: handle errors here and just ignore the file?
                # NOTE: For a lot of these files (like bytecode),
                # we could get away with just unpickling, but this should be safer.
                obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
            buf = io.StringIO()
            pprint.pprint(obj, buf)
            contents = buf.getvalue()
            # Checked the rendered length instead of the file size
            # because pickles with shared structure can explode in size during rendering.
            if os.path.basename(zi.filename) not in always_render_pickles and \
                    len(contents) > extra_file_size_limit:
                continue
            extra_pickles[zi.filename] = contents

    # interned_strings iterates keys in insertion order, which matches the
    # dense ids assigned by ist(), so a plain list round-trips the table.
    return {"model": dict(
        title=title,
        file_size=file_size,
        version=version,
        zip_files=zip_files,
        interned_strings=list(interned_strings),
        code_files=code_files,
        model_data=model_data,
        constants=constants,
        extra_files_jsons=extra_files_jsons,
        extra_pickles=extra_pickles,
    )}
  317. def get_inline_skeleton():
  318. """Get a fully-inlined skeleton of the frontend.
  319. The returned HTML page has no external network dependencies for code.
  320. It can load model_info.json over HTTP, or be passed to burn_in_info.
  321. """
  322. import importlib.resources
  323. skeleton = importlib.resources.read_text(__package__, "skeleton.html")
  324. js_code = importlib.resources.read_text(__package__, "code.js")
  325. for js_module in ["preact", "htm"]:
  326. js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
  327. js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
  328. js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)
  329. skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code)
  330. return skeleton
  331. def burn_in_info(skeleton, info):
  332. """Burn model info into the HTML skeleton.
  333. The result will render the hard-coded model info and
  334. have no external network dependencies for code or data.
  335. """
  336. # Note that Python's json serializer does not escape slashes in strings.
  337. # Since we're inlining this JSON directly into a script tag, a string
  338. # containing "</script>" would end the script prematurely and
  339. # mess up our page. Unconditionally escape fixes that.
  340. return skeleton.replace(
  341. "BURNED_IN_MODEL_INFO = null",
  342. "BURNED_IN_MODEL_INFO = " + json.dumps(info, sort_keys=True).replace("/", "\\/"))
  343. def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
  344. model_info = get_model_info(path_or_bytesio, **kwargs)
  345. skeleton = get_inline_skeleton()
  346. page = burn_in_info(skeleton, model_info)
  347. return page
  348. def main(argv, *, stdout=None):
  349. parser = argparse.ArgumentParser()
  350. parser.add_argument("--style", choices=["json", "html"])
  351. parser.add_argument("--title")
  352. parser.add_argument("model")
  353. args = parser.parse_args(argv[1:])
  354. info = get_model_info(args.model, title=args.title)
  355. output = stdout or sys.stdout
  356. if args.style == "json":
  357. output.write(json.dumps(info, sort_keys=True) + "\n")
  358. elif args.style == "html":
  359. skeleton = get_inline_skeleton()
  360. page = burn_in_info(skeleton, info)
  361. output.write(page)
  362. else:
  363. raise Exception("Invalid style")