# Unpickler restricted to loading only state dicts
# Restrict constructing types to a list defined in _get_allowed_globals()
# Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only
# Restrict APPEND/APPENDS to `list`
# In the `GLOBAL` operation, do not look classes up by name, but rather rely on
# the dictionary returned by `_get_allowed_globals()`, which contains:
# - torch types (Storage, dtypes, Tensor, `torch.Size`)
# - `torch._utils._rebuild` functions
# - `torch.nn.Parameter`
# - `collections.OrderedDict`
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
# Expected to be useful for loading PyTorch model weights
# For example:
#   data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
#   buf = io.BytesIO(data)
#   weights = torch.load(buf, weights_only=True)
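#
# The module-level `load()` defined at the bottom of this file can also be
# driven directly on any pickle stream built from allowed opcodes and types,
# e.g. (a minimal sketch; pickle protocol 2 is assumed, since FRAME and the
# other protocol-4+ opcodes are not handled below):
#   buf = io.BytesIO(pickle.dumps(OrderedDict(a=1), protocol=2))
#   obj = load(buf)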

import functools as _functools
from collections import OrderedDict
from pickle import (
    APPEND,
    APPENDS,
    BINFLOAT,
    BINGET,
    BININT,
    BININT1,
    BININT2,
    BINPERSID,
    BINPUT,
    BINUNICODE,
    BUILD,
    bytes_types,
    decode_long,
    EMPTY_DICT,
    EMPTY_LIST,
    EMPTY_SET,
    EMPTY_TUPLE,
    GLOBAL,
    LONG1,
    LONG_BINGET,
    LONG_BINPUT,
    MARK,
    NEWFALSE,
    NEWOBJ,
    NEWTRUE,
    NONE,
    PROTO,
    REDUCE,
    SETITEM,
    SETITEMS,
    SHORT_BINSTRING,
    STOP,
    TUPLE,
    TUPLE1,
    TUPLE2,
    TUPLE3,
    UnpicklingError,
)
from struct import unpack
from sys import maxsize
from typing import Any, Dict, List

import torch


# Unpickling machinery
@_functools.lru_cache(maxsize=1)
def _get_allowed_globals():
    rc: Dict[str, Any] = {
        "collections.OrderedDict": OrderedDict,
        "torch.nn.parameter.Parameter": torch.nn.Parameter,
        "torch.serialization._get_layout": torch.serialization._get_layout,
        "torch.Size": torch.Size,
        "torch.Tensor": torch.Tensor,
    }
    # dtype
    for t in [
        torch.complex32,
        torch.complex64,
        torch.complex128,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ]:
        rc[str(t)] = t
    # Tensor classes
    for tt in torch._tensor_classes:
        rc[f"{tt.__module__}.{tt.__name__}"] = tt
    # Storage classes
    for ts in torch._storage_classes:
        rc[f"{ts.__module__}.{ts.__name__}"] = ts
    # Rebuild functions
    for f in [
        torch._utils._rebuild_parameter,
        torch._utils._rebuild_tensor,
        torch._utils._rebuild_tensor_v2,
        torch._utils._rebuild_sparse_tensor,
        torch._utils._rebuild_meta_tensor_no_storage,
    ]:
        rc[f"torch._utils.{f.__name__}"] = f
    # Handles Tensor subclasses and Tensors with attributes.
    # NOTE: It calls into the rebuild functions above for regular Tensor types.
    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
    return rc
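

# For illustration, the allowlist above maps fully qualified names to objects;
# representative entries (exact contents depend on the torch build) look like:
#   "torch.float32"                   -> torch.float32
#   "torch.FloatStorage"              -> torch.FloatStorage
#   "torch._utils._rebuild_tensor_v2" -> torch._utils._rebuild_tensor_v2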


class Unpickler:
    def __init__(self, file, *, encoding: str = "bytes"):
        self.encoding = encoding
        self.readline = file.readline
        self.read = file.read
        self.memo: Dict[int, Any] = {}

    def load(self):
        """Read a pickled object representation from the open file.

        Return the reconstituted object hierarchy specified in the file.
        """
        self.metastack = []
        self.stack: List[Any] = []
        self.append = self.stack.append
        read = self.read
        readline = self.readline
        while True:
            key = read(1)
            if not key:
                raise EOFError
            assert isinstance(key, bytes_types)
            # Risky operators
            if key[0] == GLOBAL[0]:
                module = readline()[:-1].decode("utf-8")
                name = readline()[:-1].decode("utf-8")
                full_path = f"{module}.{name}"
                if full_path in _get_allowed_globals():
                    self.append(_get_allowed_globals()[full_path])
                else:
                    raise RuntimeError(f"Unsupported class {full_path}")
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
                if cls is not torch.nn.Parameter:
                    raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
                self.append(torch.nn.Parameter(*args))
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
                if func not in _get_allowed_globals().values():
                    raise RuntimeError(
                        f"Trying to call reduce for unrecognized function {func}"
                    )
                self.stack[-1] = func(*args)
            elif key[0] == BUILD[0]:
                state = self.stack.pop()
                inst = self.stack[-1]
                if type(inst) is torch.Tensor:
                    # Legacy unpickling
                    inst.set_(*state)
                elif type(inst) is torch.nn.Parameter:
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    inst.__dict__.update(state)
                else:
                    raise RuntimeError(
                        f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
                item = self.stack.pop()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only append to lists, but got {type(list_obj)}"
                    )
                list_obj.append(item)
            elif key[0] == APPENDS[0]:
                items = self.pop_mark()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only extend lists, but got {type(list_obj)}"
                    )
                list_obj.extend(items)
            elif key[0] == SETITEM[0]:
                (v, k) = (self.stack.pop(), self.stack.pop())
                self.stack[-1][k] = v
            elif key[0] == SETITEMS[0]:
                items = self.pop_mark()
                for i in range(0, len(items), 2):
                    self.stack[-1][items[i]] = items[i + 1]
            elif key[0] == MARK[0]:
                self.metastack.append(self.stack)
                self.stack = []
                self.append = self.stack.append
            elif key[0] == TUPLE[0]:
                items = self.pop_mark()
                self.append(tuple(items))
            elif key[0] == TUPLE1[0]:
                self.stack[-1] = (self.stack[-1],)
            elif key[0] == TUPLE2[0]:
                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
            elif key[0] == TUPLE3[0]:
                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
            # Basic types construction
            elif key[0] == NONE[0]:
                self.append(None)
            elif key[0] == NEWFALSE[0]:
                self.append(False)
            elif key[0] == NEWTRUE[0]:
                self.append(True)
            elif key[0] == EMPTY_TUPLE[0]:
                self.append(())
            elif key[0] == EMPTY_LIST[0]:
                self.append([])
            elif key[0] == EMPTY_DICT[0]:
                self.append({})
            elif key[0] == EMPTY_SET[0]:
                self.append(set())
            elif key[0] == BININT[0]:
                self.append(unpack("<i", read(4))[0])
            elif key[0] == BININT1[0]:
                self.append(self.read(1)[0])
            elif key[0] == BININT2[0]:
                self.append(unpack("<H", read(2))[0])
            elif key[0] == BINFLOAT[0]:
                self.append(unpack(">d", self.read(8))[0])
            elif key[0] == BINUNICODE[0]:
                strlen = unpack("<I", read(4))[0]
                if strlen > maxsize:
                    raise RuntimeError("String is too long")
                strval = str(read(strlen), "utf-8", "surrogatepass")
                self.append(strval)
            elif key[0] == SHORT_BINSTRING[0]:
                strlen = read(1)[0]
                strdata = read(strlen)
                if self.encoding != "bytes":
                    strdata = strdata.decode(self.encoding, "strict")
                self.append(strdata)
            elif key[0] == BINPERSID[0]:
                pid = self.stack.pop()
                # Only allow persistent load of storage
                if type(pid) is not tuple and type(pid) is not int:
                    raise RuntimeError(
                        f"persistent_load id must be tuple or int, but got {type(pid)}"
                    )
                # torch.save writes storage persistent ids of the form
                # ('storage', storage_type, key, location, numel)
                if (
                    type(pid) is tuple
                    and len(pid) > 0
                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                ):
                    raise RuntimeError(
                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
                    )
                self.append(self.persistent_load(pid))
            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
                self.append(self.memo[idx])
            elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
                i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
                if i < 0:
                    raise ValueError("negative argument")
                self.memo[i] = self.stack[-1]
            elif key[0] == LONG1[0]:
                n = read(1)[0]
                data = read(n)
                self.append(decode_long(data))
            # First and last deserializer ops
            elif key[0] == PROTO[0]:
                # Read and ignore proto version
                read(1)[0]
            elif key[0] == STOP[0]:
                rc = self.stack.pop()
                return rc
            else:
                raise RuntimeError(f"Unsupported operand {key[0]}")

    # Return a list of items pushed onto the stack after the last MARK instruction.
    def pop_mark(self):
        items = self.stack
        self.stack = self.metastack.pop()
        self.append = self.stack.append
        return items

    def persistent_load(self, pid):
        raise UnpicklingError("unsupported persistent id encountered")


def load(file, *, encoding: str = "ASCII"):
    return Unpickler(file, encoding=encoding).load()
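

if __name__ == "__main__":
    # A minimal self-check sketch (not part of the module proper): round-trip
    # an OrderedDict through the restricted loader, then confirm that a pickle
    # invoking a callable outside _get_allowed_globals() is rejected. The
    # `Evil` class below is a hypothetical stand-in for a malicious payload;
    # protocol 2 is used because protocol-4+ opcodes are not handled above.
    import io
    import pickle

    safe = OrderedDict(lr=0.1, steps=100)
    assert load(io.BytesIO(pickle.dumps(safe, protocol=2))) == safe

    class Evil:
        def __reduce__(self):
            # Reduces to a call of builtins.print, which is not allowlisted
            return (print, ("this should never run",))

    try:
        load(io.BytesIO(pickle.dumps(Evil(), protocol=2)))
    except RuntimeError as e:
        print(f"Rejected as expected: {e}")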