| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 | 
							- # Unpickler restricted to loading only state dicts
 
- # Restrict constructing types to a list defined in _get_allowed_globals()
 
- # Restrict BUILD operation to `Tensor`, `Parameter` and `OrderedDict` types only
 
- # Restrict APPEND/APPENDS to `list`
 
- # In `GLOBAL` operation do not do class lookup by name, but rather rely on a dictionary
 
- # defined by `_get_allowed_globals()` method, that contains:
 
- # - torch types (Storage, dtypes, Tensor, `torch.Size`),
 
- # - `torch._utils._rebuild` functions.
 
- # - `torch.nn.Parameter`
 
- # - `collections.OrderedDict`
 
- # Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
 
- # Expected to be useful for loading PyTorch model weights
 
- # For example:
 
- # data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
 
- # buf = io.BytesIO(data)
 
- # weights = torch.load(buf, weights_only = True)
 
- import functools as _functools
 
- from collections import OrderedDict
 
- from pickle import (
 
-     APPEND,
 
-     APPENDS,
 
-     BINFLOAT,
 
-     BINGET,
 
-     BININT,
 
-     BININT1,
 
-     BININT2,
 
-     BINPERSID,
 
-     BINPUT,
 
-     BINUNICODE,
 
-     BUILD,
 
-     bytes_types,
 
-     decode_long,
 
-     EMPTY_DICT,
 
-     EMPTY_LIST,
 
-     EMPTY_SET,
 
-     EMPTY_TUPLE,
 
-     GLOBAL,
 
-     LONG1,
 
-     LONG_BINGET,
 
-     LONG_BINPUT,
 
-     MARK,
 
-     NEWFALSE,
 
-     NEWOBJ,
 
-     NEWTRUE,
 
-     NONE,
 
-     PROTO,
 
-     REDUCE,
 
-     SETITEM,
 
-     SETITEMS,
 
-     SHORT_BINSTRING,
 
-     STOP,
 
-     TUPLE,
 
-     TUPLE1,
 
-     TUPLE2,
 
-     TUPLE3,
 
-     UnpicklingError,
 
- )
 
- from struct import unpack
 
- from sys import maxsize
 
- from typing import Any, Dict, List
 
- import torch
 
- # Unpickling machinery
 
@_functools.lru_cache(maxsize=1)
def _get_allowed_globals():
    """Build the allow-list of globals the restricted unpickler may resolve.

    Keys are ``"<module>.<name>"`` strings, compared against the full path read
    from a ``GLOBAL`` opcode; values are the objects pushed onto the stack in
    their place. ``Unpickler`` also treats the values as the only callables
    that ``REDUCE`` is allowed to invoke.

    The table is cached (``lru_cache(maxsize=1)``) so it is built only once.

    Returns:
        Dict[str, Any]: mapping from dotted path to allowed object.
    """
    rc: Dict[str, Any] = {
        "collections.OrderedDict": OrderedDict,
        "torch.nn.parameter.Parameter": torch.nn.Parameter,
        "torch.serialization._get_layout": torch.serialization._get_layout,
        "torch.Size": torch.Size,
        "torch.Tensor": torch.Tensor,
    }
    # dtypes: str(dtype) is the dotted name (e.g. "torch.float32"), which is the
    # key format the unpickler looks up. bool/uint8/bfloat16 are included since
    # they are common weight/buffer dtypes; dtype objects are not callable, so
    # allowing them grants no construction capability.
    for t in [
        torch.bool,
        torch.uint8,
        torch.bfloat16,
        torch.complex32,
        torch.complex64,
        torch.complex128,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ]:
        rc[str(t)] = t
    # Tensor classes
    for tt in torch._tensor_classes:
        rc[f"{tt.__module__}.{tt.__name__}"] = tt
    # Storage classes
    for ts in torch._storage_classes:
        rc[f"{ts.__module__}.{ts.__name__}"] = ts
    # Rebuild functions
    for f in [
        torch._utils._rebuild_parameter,
        torch._utils._rebuild_tensor,
        torch._utils._rebuild_tensor_v2,
        torch._utils._rebuild_sparse_tensor,
        torch._utils._rebuild_meta_tensor_no_storage,
    ]:
        rc[f"torch._utils.{f.__name__}"] = f
    # Handles Tensor Subclasses, Tensor's with attributes.
    # NOTE: It calls into above rebuild functions for regular Tensor types.
    rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
    return rc
 
class Unpickler:
    """Restricted unpickler intended for loading PyTorch state dicts.

    Only a subset of pickle opcodes is understood; any other opcode raises
    ``RuntimeError``. Globals are never imported by name: they are resolved
    exclusively through the allow-list returned by ``_get_allowed_globals()``.
    Object-constructing and mutating opcodes are further restricted:

    * ``NEWOBJ`` may only instantiate ``torch.nn.Parameter``;
    * ``REDUCE`` may only call objects present in the allow-list;
    * ``BUILD`` only applies to exact ``torch.Tensor``, ``torch.nn.Parameter``
      and ``OrderedDict`` instances;
    * ``APPEND``/``APPENDS`` only operate on exact ``list`` instances;
    * ``BINPERSID`` only accepts ``tuple`` or ``int`` ids, and tuple ids must
      be tagged ``"storage"``.

    ``persistent_load`` raises ``UnpicklingError`` by default; pickles that
    contain persistent ids can only be loaded if the caller overrides it
    (subclass, or assign a callable to the instance attribute).
    """

    def __init__(self, file, *, encoding: str = "bytes"):
        """Create an unpickler reading from *file*.

        Args:
            file: object exposing ``read(n)`` and ``readline()`` returning
                bytes (e.g. a binary file or ``io.BytesIO``).
            encoding: codec used to decode ``SHORT_BINSTRING`` payloads; the
                default ``"bytes"`` leaves them as raw bytes.
        """
        self.encoding = encoding
        self.readline = file.readline
        self.read = file.read
        # Memo table: written by BINPUT/LONG_BINPUT, read by BINGET/LONG_BINGET.
        self.memo: Dict[int, Any] = {}

    def load(self):
        """Read a pickled object representation from the open file.

        Return the reconstituted object hierarchy specified in the file.

        Raises:
            EOFError: if the stream ends before a ``STOP`` opcode.
            RuntimeError: on an unsupported opcode, a global outside the
                allow-list, or a disallowed construction/mutation.
            UnpicklingError: from the default ``persistent_load``.
        """
        # metastack holds the outer stacks saved by MARK; pop_mark restores them.
        self.metastack = []
        self.stack: List[Any] = []
        self.append = self.stack.append
        read = self.read
        readline = self.readline
        while True:
            key = read(1)
            if not key:
                raise EOFError
            assert isinstance(key, bytes_types)
            # Risky operators
            if key[0] == GLOBAL[0]:
                module = readline()[:-1].decode("utf-8")
                name = readline()[:-1].decode("utf-8")
                full_path = f"{module}.{name}"
                if full_path in _get_allowed_globals():
                    self.append(_get_allowed_globals()[full_path])
                else:
                    raise RuntimeError(f"Unsupported class {full_path}")
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
                if cls is not torch.nn.Parameter:
                    raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
                self.append(torch.nn.Parameter(*args))
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
                if func not in _get_allowed_globals().values():
                    raise RuntimeError(
                        f"Trying to call reduce for unrecognized function {func}"
                    )
                self.stack[-1] = func(*args)
            elif key[0] == BUILD[0]:
                state = self.stack.pop()
                inst = self.stack[-1]
                if type(inst) is torch.Tensor:
                    # Legacy unpickling
                    inst.set_(*state)
                elif type(inst) is torch.nn.Parameter:
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    inst.__dict__.update(state)
                else:
                    raise RuntimeError(
                        f"Can only build Tensor, parameter or dict objects, but got {type(inst)}"
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
                item = self.stack.pop()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only append to lists, but got {type(list_obj)}"
                    )
                list_obj.append(item)
            elif key[0] == APPENDS[0]:
                items = self.pop_mark()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise RuntimeError(
                        f"Can only extend lists, but got {type(list_obj)}"
                    )
                list_obj.extend(items)
            elif key[0] == SETITEM[0]:
                (v, k) = (self.stack.pop(), self.stack.pop())
                self.stack[-1][k] = v
            elif key[0] == SETITEMS[0]:
                # Items after the MARK alternate key, value, key, value, ...
                items = self.pop_mark()
                for i in range(0, len(items), 2):
                    self.stack[-1][items[i]] = items[i + 1]
            elif key[0] == MARK[0]:
                # Save the current stack and start a fresh one for the marked items.
                self.metastack.append(self.stack)
                self.stack = []
                self.append = self.stack.append
            elif key[0] == TUPLE[0]:
                items = self.pop_mark()
                self.append(tuple(items))
            elif key[0] == TUPLE1[0]:
                self.stack[-1] = (self.stack[-1],)
            elif key[0] == TUPLE2[0]:
                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
            elif key[0] == TUPLE3[0]:
                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
            # Basic types construction
            elif key[0] == NONE[0]:
                self.append(None)
            elif key[0] == NEWFALSE[0]:
                self.append(False)
            elif key[0] == NEWTRUE[0]:
                self.append(True)
            elif key[0] == EMPTY_TUPLE[0]:
                self.append(())
            elif key[0] == EMPTY_LIST[0]:
                self.append([])
            elif key[0] == EMPTY_DICT[0]:
                self.append({})
            elif key[0] == EMPTY_SET[0]:
                self.append(set())
            elif key[0] == BININT[0]:
                self.append(unpack("<i", read(4))[0])
            elif key[0] == BININT1[0]:
                self.append(read(1)[0])
            elif key[0] == BININT2[0]:
                self.append(unpack("<H", read(2))[0])
            elif key[0] == BINFLOAT[0]:
                self.append(unpack(">d", read(8))[0])
            elif key[0] == BINUNICODE[0]:
                strlen = unpack("<I", read(4))[0]
                if strlen > maxsize:
                    raise RuntimeError("String is too long")
                strval = str(read(strlen), "utf-8", "surrogatepass")
                self.append(strval)
            elif key[0] == SHORT_BINSTRING[0]:
                strlen = read(1)[0]
                strdata = read(strlen)
                if self.encoding != "bytes":
                    strdata = strdata.decode(self.encoding, "strict")
                self.append(strdata)
            elif key[0] == BINPERSID[0]:
                pid = self.stack.pop()
                # Only allow persistent load of storage. Exact type checks, so
                # subclasses of tuple/int (and bool) are rejected as well.
                if type(pid) is not tuple and type(pid) is not int:
                    raise RuntimeError(
                        f"persistent_load id must be tuple or int, but got {type(pid)}"
                    )
                if (
                    type(pid) is tuple
                    and len(pid) > 0
                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                ):
                    raise RuntimeError(
                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
                    )
                self.append(self.persistent_load(pid))
            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
                self.append(self.memo[idx])
            elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
                i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
                if i < 0:
                    raise ValueError("negative argument")
                self.memo[i] = self.stack[-1]
            elif key[0] == LONG1[0]:
                n = read(1)[0]
                data = read(n)
                self.append(decode_long(data))
            # First and last deserializer ops
            elif key[0] == PROTO[0]:
                # Read and ignore proto version
                read(1)[0]
            elif key[0] == STOP[0]:
                rc = self.stack.pop()
                return rc
            else:
                raise RuntimeError(f"Unsupported operand {key[0]}")

    # Return a list of items pushed in the stack after last MARK instruction.
    def pop_mark(self):
        """Return the items pushed since the last MARK and restore the outer stack."""
        items = self.stack
        self.stack = self.metastack.pop()
        self.append = self.stack.append
        return items

    def persistent_load(self, pid):
        """Resolve a persistent id; unsupported unless overridden by the caller."""
        raise UnpicklingError("unsupported persistent id encountered")
 
def load(file, *, encoding: str = "ASCII"):
    """Deserialize one object from *file* using the restricted ``Unpickler``.

    Args:
        file: binary file-like object providing ``read`` and ``readline``.
        encoding: codec forwarded to ``Unpickler`` for ``SHORT_BINSTRING``
            payloads (``"bytes"`` would keep them undecoded).

    Returns:
        The reconstructed object.
    """
    unpickler = Unpickler(file, encoding=encoding)
    return unpickler.load()
 
 
  |