show_pickle.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. #!/usr/bin/env python3
  2. import sys
  3. import pickle
  4. import struct
  5. import pprint
  6. import zipfile
  7. import fnmatch
  8. from typing import Any, IO, BinaryIO, Union
  # Names exported by a star import of this module.
  __all__ = [
      "FakeObject",
      "FakeClass",
      "DumpUnpickler",
      "main",
  ]
  class FakeObject:
      """Inert stand-in for an object rebuilt from a pickle.

      Records the dotted location (``module`` and ``name``) of the class the
      pickle referred to, the positional arguments it was "constructed" with,
      and any state later delivered through ``__setstate__``.
      """

      def __init__(self, module, name, args):
          self.module = module
          self.name = name
          self.args = args
          # NOTE: We don't distinguish between state never set and state set to None.
          self.state = None

      def __repr__(self):
          state_str = "" if self.state is None else f"(state={self.state!r})"
          return f"{self.module}.{self.name}{self.args!r}{state_str}"

      def __setstate__(self, state):
          # Invoked by pickle's BUILD opcode; keep the raw state for display.
          self.state = state

      @staticmethod
      def pp_format(printer, obj, stream, indent, allowance, context, level):
          """pprint dispatch hook that lays ``obj`` out across lines.

          The signature matches the functions stored in
          ``pprint.PrettyPrinter._dispatch``; ``printer`` is the
          PrettyPrinter doing the formatting. Output shapes:

          * no args, no state: ``repr(obj)``
          * args only:         ``module.name(<args>)``
          * state only:        ``module.name()(state=``, newline, indented state, ``)``
          * args and state:    ``module.name(<args>)(state=``, newline, indented state, ``)``
          """
          if not obj.args and obj.state is None:
              stream.write(repr(obj))
              return
          qualname = f"{obj.module}.{obj.name}"
          if obj.state is None:
              stream.write(qualname)
              printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
              return
          # The object has state. Emit the constructor part first; previously
          # the args-and-state combination was unimplemented and raised.
          if obj.args:
              stream.write(qualname)
              printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
              stream.write("(state=\n")
          else:
              stream.write(f"{qualname}()(state=\n")
          # Put the state on its own line, one indentation level deeper.
          indent += printer._indent_per_level
          stream.write(" " * indent)
          printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
          stream.write(")")
  class FakeClass:
      """Placeholder for a class reference found in a pickle.

      ``DumpUnpickler.find_class`` hands these out instead of importing the
      real class. Both ways pickle builds an instance from it yield a
      ``FakeObject`` carrying the module, the name and the arguments:
      calling the placeholder directly, or calling its ``__new__`` with the
      placeholder itself as the first argument.
      """

      def __init__(self, module, name):
          self.module, self.name = module, name
          # An instance attribute wins over the inherited object.__new__ when
          # code evaluates ``cls.__new__(cls, *args)`` on this placeholder.
          self.__new__ = self.fake_new  # type: ignore[assignment]

      def __repr__(self):
          return "{}.{}".format(self.module, self.name)

      def __call__(self, *args):
          return FakeObject(self.module, self.name, args)

      def fake_new(self, *args):
          # The leading positional argument is the "class" passed to __new__;
          # only what follows it are the constructor arguments.
          ctor_args = args[1:]
          return FakeObject(self.module, self.name, ctor_args)
  class DumpUnpickler(pickle._Unpickler):  # type: ignore[name-defined]
      """Pure-Python unpickler that rebuilds objects as placeholders for display.

      Class lookups (``find_class``) resolve to ``FakeClass`` instead of
      importing the named module, and persistent IDs become
      ``FakeObject("pers", "obj", (pid,))``. With ``catch_invalid_utf8=True``,
      binary string opcodes whose payload is not valid UTF-8 push a
      ``FakeObject("builtin", "UnicodeDecodeError", ...)`` sentinel instead of
      aborting the load.
      """

      def __init__(self, file, *, catch_invalid_utf8=False, **kwargs):
          # Remaining keyword arguments are forwarded to pickle._Unpickler.
          super().__init__(file, **kwargs)
          self.catch_invalid_utf8 = catch_invalid_utf8

      def find_class(self, module, name):
          # Record where the class lives rather than importing it.
          return FakeClass(module, name)

      def persistent_load(self, pid):
          return FakeObject("pers", "obj", (pid,))

      # Private copy, so the overrides below do not alter pickle._Unpickler.
      dispatch = dict(pickle._Unpickler.dispatch)  # type: ignore[attr-defined]

      # Custom objects in TorchScript are able to return invalid UTF-8 strings
      # from their pickle (__getstate__) functions. Install custom loaders for
      # every binary string opcode (BINUNICODE, plus SHORT_BINUNICODE and
      # BINUNICODE8 from protocol 4) that catch the decode exception and
      # replace the string with a sentinel object. Previously only BINUNICODE
      # was covered, so protocol-4 pickles still failed on bad UTF-8.
      def _append_decoded_str(self, str_bytes):
          """Decode ``str_bytes`` like pickle does and push the result.

          Raises:
              UnicodeDecodeError: if the bytes are not valid UTF-8 (lone
                  surrogates are accepted via ``surrogatepass``) and
                  ``catch_invalid_utf8`` is false.
          """
          obj: Any
          try:
              obj = str(str_bytes, "utf-8", "surrogatepass")
          except UnicodeDecodeError as exn:
              if not self.catch_invalid_utf8:
                  raise
              obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),))
          self.append(obj)  # type: ignore[attr-defined]

      def load_binunicode(self):
          # 4-byte little-endian length prefix.
          strlen, = struct.unpack("<I", self.read(4))  # type: ignore[attr-defined]
          if strlen > sys.maxsize:
              raise Exception("String too long.")
          self._append_decoded_str(self.read(strlen))  # type: ignore[attr-defined]
      dispatch[pickle.BINUNICODE[0]] = load_binunicode  # type: ignore[assignment]

      def load_short_binunicode(self):
          # 1-byte length prefix.
          strlen = self.read(1)[0]  # type: ignore[attr-defined]
          self._append_decoded_str(self.read(strlen))  # type: ignore[attr-defined]
      dispatch[pickle.SHORT_BINUNICODE[0]] = load_short_binunicode  # type: ignore[assignment]

      def load_binunicode8(self):
          # 8-byte little-endian length prefix.
          strlen, = struct.unpack("<Q", self.read(8))  # type: ignore[attr-defined]
          if strlen > sys.maxsize:
              raise Exception("String too long.")
          self._append_decoded_str(self.read(strlen))  # type: ignore[attr-defined]
      dispatch[pickle.BINUNICODE8[0]] = load_binunicode8  # type: ignore[assignment]

      @classmethod
      def dump(cls, in_stream, out_stream):
          """Unpickle ``in_stream``, pprint the result to ``out_stream``, return it."""
          value = cls(in_stream).load()
          pprint.pprint(value, stream=out_stream)
          return value
  def main(argv, output_stream=None):
      """Command-line entry point: pretty-print the pickle named by ``argv[1]``.

      ``argv[1]`` may be:

      * a plain path to a pickle file;
      * ``ARCHIVE.zip@MEMBER`` to read one member of a zip archive;
      * ``ARCHIVE.zip@PATTERN`` where PATTERN contains ``*`` and is matched
        against member names with ``fnmatch``. Only the first match (in
        archive order) is shown.

      Args:
          argv: Full argument vector, program name at index 0.
          output_stream: Text stream that receives the output; ``None`` lets
              pprint write to stdout.

      Returns:
          2 after printing usage to stderr when ``argv`` does not hold exactly
          two entries and ``output_stream`` is ``None``; otherwise ``None``.

      Raises:
          Exception: if ``argv`` has the wrong length while ``output_stream``
              is given, or if no zip member matches a glob pattern.
      """
      if len(argv) != 2:
          # Don't spam stderr if not using stdout.
          if output_stream is not None:
              raise Exception("Pass argv of length 2.")
          usage = (
              "usage: show_pickle PICKLE_FILE\n",
              " PICKLE_FILE can be any of:\n",
              " path to a pickle file\n",
              " file.zip@member.pkl\n",
              " file.zip@*/pattern.*\n",
              " (shell glob pattern for members)\n",
              " (only first match will be shown)\n",
          )
          for usage_line in usage:
              sys.stderr.write(usage_line)
          return 2

      target = argv[1]
      handle: Union[IO[bytes], BinaryIO]
      if "@" in target:
          archive_path, member = target.split("@", 1)
          with zipfile.ZipFile(archive_path) as archive:
              if "*" in member:
                  # Glob pattern: take the first matching entry in archive order.
                  first_match = next(
                      (
                          info
                          for info in archive.infolist()
                          if fnmatch.fnmatch(info.filename, member)
                      ),
                      None,
                  )
                  if first_match is None:
                      raise Exception(f"Could not find member matching {member} in {archive_path}")
                  with archive.open(first_match) as handle:
                      DumpUnpickler.dump(handle, output_stream)
              else:
                  with archive.open(member) as handle:
                      DumpUnpickler.dump(handle, output_stream)
      else:
          with open(target, "rb") as handle:
              DumpUnpickler.dump(handle, output_stream)
  if __name__ == "__main__":
      # Teach pprint to lay out FakeObject instances with FakeObject.pp_format.
      # This pokes the private PrettyPrinter._dispatch table (keyed by a type's
      # __repr__ function), which was verified on Python 3.7.4. Install the
      # hook only when that table exists; otherwise pprint simply falls back
      # to repr() instead of the script crashing with AttributeError.
      if hasattr(pprint.PrettyPrinter, "_dispatch"):
          pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format  # type: ignore[attr-defined]
      sys.exit(main(sys.argv))