storage.py

import io
import torch
from ._utils import _type, _cuda
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast
import copy
import collections
from functools import lru_cache
import warnings

try:
    import numpy as np
    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')


class _StorageBase:
    _cdata: Any
    is_sparse: bool = False
    is_sparse_csr: bool = False
    device: torch.device

    def __init__(self, *args, **kwargs): ...  # noqa: E704
    def __len__(self) -> int: ...  # noqa: E704
    def __getitem__(self, idx): ...  # noqa: E704
    def copy_(self, source: T, non_blocking: bool = None) -> T: ...  # noqa: E704
    def nbytes(self) -> int: ...  # noqa: E704

    def size(self) -> int:
        return self.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> T: ...  # noqa: E704
    def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ...  # noqa: E704
    def element_size(self) -> int: ...  # noqa: E704
    def get_device(self) -> int: ...  # noqa: E704
    def data_ptr(self) -> int: ...  # noqa: E704

    # Defined in torch/csrc/generic/StorageSharing.cpp
    def _share_filename_cpu_(self, *args, **kwargs): ...  # noqa: E704
    def _share_fd_cpu_(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ...  # noqa: E704
    @classmethod
    def from_buffer(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ...  # noqa: E704
    @classmethod
    def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    @classmethod
    def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_decref(self) -> T: ...  # noqa: E704
    def _write_file(self, *args, **kwargs): ...  # noqa: E704
    def resize_(self, size: int): ...  # noqa: E704
    def _weak_ref(self, *args, **kwargs) -> T: ...  # noqa: E704
    def is_pinned(self) -> bool: ...  # noqa: E704
    def _set_from_file(self, *args, **kwargs): ...  # noqa: E704
    def _set_cdata(self, *args, **kwargs): ...  # noqa: E704
    def _share_cuda_(self, *args, **kwargs): ...  # noqa: E704
    def is_shared(self) -> bool: ...  # noqa: E704
    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
    def _shared_incref(self, *args, **kwargs): ...  # noqa: E704
    @classmethod
    def _free_weak_ref(cls, *args, **kwargs): ...  # noqa: E704
    @property
    def is_cuda(self): ...  # noqa: E704
    @classmethod
    def from_file(cls, filename, shared, nbytes) -> T: ...  # noqa: E704
    @classmethod
    def _expired(cls, *args, **kwargs) -> T: ...  # noqa: E704

    def __str__(self):
        info_str = (
            f'[{torch.typename(self)}(device={self.device}) '
            f'of size {len(self)}]')
        if self.device.type == 'meta':
            return '...\n' + info_str
        else:
            data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
            return data_str + '\n' + info_str

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(map(lambda i: self[i], range(self.size())))

    def __copy__(self):
        return self.clone()

    def __deepcopy__(self, memo):
        memo = memo.setdefault('torch', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_storage = self.clone()
        memo[self._cdata] = new_storage
        return new_storage
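
    # __deepcopy__ keys its memo on the underlying `_cdata` pointer rather than
    # on the Python wrapper object, so distinct wrappers around the same C++
    # storage are cloned only once per deepcopy pass and keep aliasing each
    # other in the copy.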

    def __reduce__(self):
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def __sizeof__(self):
        return super().__sizeof__() + self.size()

    def clone(self):
        """Returns a copy of this storage"""
        return type(self)(self.nbytes(), device=self.device).copy_(self)

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        if self.device.type != 'cpu':
            return torch.UntypedStorage(self.size()).copy_(self, False)
        else:
            return self

    def mps(self):
        """Returns an MPS copy of this storage if it's not already on MPS"""
        if self.device.type != 'mps':
            return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
        else:
            return self

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype)._typed_storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage
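
    # The dtype-cast helpers below (double(), float(), ..., complex_float()) all
    # funnel through _to(): the raw bytes are viewed as a uint8 tensor, converted
    # element-wise with Tensor.to(dtype), and the result is cloned whenever the
    # conversion aliased the original buffer, so the result never shares memory
    # with the original storage.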

    def double(self):
        """Casts this storage to double type"""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        return self._to(torch.cfloat)

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned."""
        if self.is_cuda:
            raise TypeError(f"cannot pin '{self.type()}'; only CPU memory can be pinned")
        import torch.cuda
        allocator = torch.cuda.memory._host_allocator()  # type: ignore[attr-defined]
        return type(self)(self.size(), allocator=allocator).copy_(self)

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        from torch.multiprocessing import get_sharing_strategy
        if self.is_cuda:
            pass  # CUDA doesn't use POSIX shared memory
        elif get_sharing_strategy() == 'file_system':
            self._share_filename_cpu_()
        else:
            self._share_fd_cpu_()
        return self
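
    # Illustrative usage (a sketch, not from this file): sharing is done in the
    # parent before handing the storage to worker processes, and the strategy
    # ('file_system' vs. file-descriptor passing) comes from torch.multiprocessing.
    #
    #   >>> s = torch.UntypedStorage(1024).share_memory_()
    #   >>> s.is_shared()
    #   True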

    @classmethod
    def _new_shared(cls, size, *, device='cpu'):
        """Creates a new storage in shared memory with the same data type"""
        from torch.multiprocessing import get_sharing_strategy
        device = torch.device(device)
        if device.type == 'cuda':
            return cls(size, device=device)
        elif get_sharing_strategy() == 'file_system':
            return cls._new_using_filename_cpu(size)
        else:
            return cls._new_using_fd_cpu(size)
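
    # Note that `size` here is interpreted by the concrete class: for
    # UntypedStorage it is a byte count, while TypedStorage._new_shared() below
    # multiplies by the element size before delegating. The CUDA branch simply
    # allocates a regular storage, since CUDA sharing goes through IPC handles
    # (_share_cuda_ / _new_shared_cuda) rather than POSIX shared memory.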

    def untyped(self):
        return self


class UntypedStorage(torch._C.StorageBase, _StorageBase):
    def __getitem__(self, *args, **kwargs):
        if self.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")
        return super().__getitem__(*args, **kwargs)

    @property
    def is_cuda(self):
        return self.device.type == 'cuda'


def _load_from_bytes(b):
    return torch.load(io.BytesIO(b))


_StorageBase.type = _type  # type: ignore[assignment]
_StorageBase.cuda = _cuda  # type: ignore[assignment]


@lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
    # NOTE: We should no longer add dtypes to this map. This map
    # is only used for BC/FC with older PyTorch versions. Going forward,
    # new dtypes of TypedStorage should not translate to a legacy
    # <type>Storage class. Instead, new dtypes of TypedStorage should
    # be serialized as an UntypedStorage paired with a torch.dtype
    return {
        torch.double: 'DoubleStorage',
        torch.float: 'FloatStorage',
        torch.half: 'HalfStorage',
        torch.long: 'LongStorage',
        torch.int: 'IntStorage',
        torch.int16: 'ShortStorage',
        torch.int8: 'CharStorage',
        torch.uint8: 'ByteStorage',
        torch.bool: 'BoolStorage',
        torch.bfloat16: 'BFloat16Storage',
        torch.cdouble: 'ComplexDoubleStorage',
        torch.cfloat: 'ComplexFloatStorage',
        torch.qint8: 'QInt8Storage',
        torch.qint32: 'QInt32Storage',
        torch.quint8: 'QUInt8Storage',
        torch.quint4x2: 'QUInt4x2Storage',
        torch.quint2x4: 'QUInt2x4Storage',
    }


@lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    dtype_map = {
        val: key for key, val in _dtype_to_storage_type_map().items()}
    return dtype_map
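
# Illustrative round trip through the two cached maps above:
#
#   >>> _dtype_to_storage_type_map()[torch.float]
#   'FloatStorage'
#   >>> _storage_type_to_dtype_map()['FloatStorage']
#   torch.float32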


def _get_storage_from_sequence(sequence, dtype, device):
    if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
        interpret_dtypes = {
            torch.quint8: torch.uint8,
            torch.quint4x2: torch.uint8,
            torch.quint2x4: torch.uint8,
            torch.qint32: torch.int32,
            torch.qint8: torch.int8
        }
        tmp_tensor = torch.tensor(
            sequence,
            dtype=interpret_dtypes[dtype],
            device=device)
    else:
        tmp_tensor = torch.tensor(
            sequence,
            dtype=dtype,
            device=device)
    return tmp_tensor._typed_storage()._untyped_storage


def _isint(x):
    if HAS_NUMPY:
        return isinstance(x, (int, np.integer))
    else:
        return isinstance(x, int)


def _warn_typed_storage_removal(stacklevel=2):
    message = (
        "TypedStorage is deprecated. It will be removed in the future and "
        "UntypedStorage will be the only storage class. This should only matter "
        "to you if you are using storages directly. To access UntypedStorage "
        "directly, use tensor.untyped_storage() instead of tensor.storage()"
    )
    warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)


class TypedStorage:
    is_sparse = False

    dtype: torch.dtype

    @property
    def _dtype(self):
        return self.dtype

    def fill_(self, value):
        _warn_typed_storage_removal()
        self._setitem(slice(0, self._size()), value)
        return self

    def __new__(cls, *args, wrap_storage=None, dtype=None, device=None, _internal=False):
        if not _internal:
            _warn_typed_storage_removal()

        if cls == torch.storage._LegacyStorage:
            raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")

        if cls == TypedStorage:
            return super().__new__(cls)

        else:
            arg_error_msg = (
                f'{cls}.__new__ received an invalid combination '
                f'of arguments. Expected one of:\n'
                ' * no arguments\n'
                ' * (int size)\n'
                ' * (Sequence data)\n'
                ' * (*, UntypedStorage wrap_storage)')

            if device is not None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nKeyword argument 'device' cannot be specified")

            if dtype is not None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nKeyword argument 'dtype' cannot be specified")

            if wrap_storage is None:
                if len(args) > 1:
                    raise RuntimeError(
                        arg_error_msg +
                        "\nToo many positional arguments")

                if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
                    raise TypeError(
                        arg_error_msg +
                        f"\nArgument type not recognized: {type(args[0])}")

                return TypedStorage(
                    *args,
                    dtype=cls._dtype,
                    device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
                    _internal=True)

            else:
                if len(args) != 0:
                    raise RuntimeError(
                        arg_error_msg +
                        "\nNo positional arguments should be given when using "
                        "'wrap_storage'")

                if not isinstance(wrap_storage, torch.UntypedStorage):
                    raise TypeError(
                        arg_error_msg +
                        f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")

                cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'

                if wrap_storage.device.type != cls_device:
                    raise RuntimeError(
                        arg_error_msg +
                        f"\nDevice of 'wrap_storage' must be {cls_device}"
                        f", but got {wrap_storage.device.type}")

                return TypedStorage(
                    *args,
                    wrap_storage=wrap_storage,
                    dtype=cls.dtype,
                    _internal=True)

    def __init__(self, *args, device=None, dtype=None, wrap_storage=None, _internal=False):
        if not _internal:
            _warn_typed_storage_removal()
        arg_error_msg = (
            'TypedStorage.__init__ received an invalid combination '
            'of arguments. Expected one of:\n'
            ' * (*, torch.device device, torch.dtype dtype)\n'
            ' * (int size, *, torch.device device, torch.dtype dtype)\n'
            ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
            ' * (*, UntypedStorage wrap_storage, torch.dtype dtype)')

        if wrap_storage is not None:
            if len(args) != 0:
                raise RuntimeError(
                    arg_error_msg +
                    "\nNo positional arguments should be given when using "
                    "'wrap_storage'")

            if dtype is None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nArgument 'dtype' must be specified")

            if not isinstance(dtype, torch.dtype):
                raise TypeError(
                    arg_error_msg +
                    f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")

            if device is not None:
                raise RuntimeError(
                    arg_error_msg +
                    "\nArgument 'device' should not be specified when 'wrap_storage' is given")

            self.dtype = dtype

            if not isinstance(wrap_storage, torch.UntypedStorage):
                raise TypeError(
                    arg_error_msg +
                    f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")

            self._untyped_storage = wrap_storage

        else:
            self.dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device('cpu' if device is None else device)

            if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
                if device.type == 'cuda':
                    raise RuntimeError("Cannot create CUDA storage with quantized dtype")

            if len(args) == 0:
                self._untyped_storage = torch.UntypedStorage(device=device)

            elif len(args) == 1:
                if _isint(args[0]):
                    self._untyped_storage = torch.UntypedStorage(int(args[0]) * self._element_size(), device=device)
                elif isinstance(args[0], collections.abc.Sequence):
                    self._untyped_storage = _get_storage_from_sequence(args[0], self.dtype, device)
                else:
                    raise TypeError(
                        arg_error_msg +
                        f"\nArgument type not recognized: {type(args[0])}")

            else:
                raise RuntimeError(
                    arg_error_msg +
                    "\nToo many positional arguments")
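
    # The accepted construction forms mirror arg_error_msg above. A hedged
    # sketch (TypedStorage itself is deprecated; `_internal=True` merely
    # suppresses the deprecation warning):
    #
    #   >>> TypedStorage(dtype=torch.float, device='cpu', _internal=True)        # empty
    #   >>> TypedStorage(10, dtype=torch.float, device='cpu', _internal=True)    # 10 elements
    #   >>> TypedStorage([1, 2, 3], dtype=torch.int64, device='cpu', _internal=True)
    #   >>> TypedStorage(wrap_storage=torch.UntypedStorage(12), dtype=torch.float, _internal=True)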

    @property
    def is_cuda(self):
        _warn_typed_storage_removal()
        return self._untyped_storage.device.type == 'cuda'

    def untyped(self):
        """Returns the internal :class:`torch.UntypedStorage`"""
        _warn_typed_storage_removal()
        return self._untyped_storage

    def _new_wrapped_storage(self, untyped_storage):
        assert type(untyped_storage) == torch.UntypedStorage

        if type(self) == TypedStorage:
            return TypedStorage(
                wrap_storage=untyped_storage,
                dtype=self.dtype,
                _internal=True)
        else:
            return type(self)(wrap_storage=untyped_storage)

    def __len__(self):
        _warn_typed_storage_removal()
        return self._size()

    def _maybe_wrap_index(self, idx, is_stop=False):
        if idx is None:
            if is_stop:
                return self._size()
            else:
                return 0

        else:
            if type(idx) != int:
                raise TypeError(
                    f"can't index a {type(self)} with {type(idx)}")
            if is_stop:
                if (idx > self._size()) or (idx < -self._size()):
                    raise IndexError(
                        f'index {idx} out of range for storage of size {self.size()}')
                if idx > 0:
                    return idx
                else:
                    return idx % self._size()
            else:
                if (idx >= self._size()) or (idx < -self._size()):
                    raise IndexError(
                        f'index {idx} out of range for storage of size {self.size()}')
                return idx % self._size()
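
    # Wrapping behaviour for a storage of size 5 (illustrative):
    #   start index:  2 -> 2,   -1 -> 4,   5 -> IndexError
    #   stop  index:  5 -> 5,   -1 -> 4,   0 -> 0,   6 -> IndexError
    #   None maps to 0 as a start index and to self._size() as a stop index.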

    def __setitem__(self, idx, value):
        _warn_typed_storage_removal()
        return self._setitem(idx, value)

    def _setitem(self, idx, value):
        if not isinstance(idx, (int, slice)):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
        if torch.is_storage(value):
            raise RuntimeError(f'cannot set item with value type {type(value)}')
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            interpret_dtypes = {
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }
            tmp_dtype = interpret_dtypes[self.dtype]
            tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self._untyped_storage.device)
            tmp_tensor.set_(TypedStorage(
                wrap_storage=self._untyped_storage,
                dtype=tmp_dtype,
                _internal=True))
        else:
            tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)

        tmp_tensor[idx] = value

    def __getitem__(self, idx):
        _warn_typed_storage_removal()
        return self._getitem(idx)

    def _getitem(self, idx):
        if self._untyped_storage.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")

        # NOTE: Before TypedStorage existed, indexing with a slice used to be
        # possible for <type>Storage objects. However, it would return
        # a storage view, which would be a hassle to implement in TypedStorage,
        # so it was disabled
        if isinstance(idx, slice):
            raise RuntimeError('slices are only supported in UntypedStorage.__getitem__')
        elif not isinstance(idx, int):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")

        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            interpret_dtypes = {
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }
            return TypedStorage(
                wrap_storage=self._untyped_storage,
                dtype=interpret_dtypes[self.dtype],
                _internal=True)._getitem(idx)

        idx_wrapped = self._maybe_wrap_index(idx)
        tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
        return tmp_tensor[idx_wrapped].item()
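
    # Reading an element goes through a temporary view: the storage is bound to
    # an empty tensor of the right dtype and the indexed value is returned as a
    # Python scalar via .item(). Quantized storages are read by reinterpreting
    # the same untyped buffer with the matching integer dtype (e.g. torch.quint8
    # data is indexed as torch.uint8 values).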

    def copy_(self, source: T, non_blocking: bool = None):
        _warn_typed_storage_removal()
        if isinstance(source, TypedStorage):
            self._untyped_storage.copy_(source._untyped_storage, non_blocking)
        else:
            self._untyped_storage.copy_(source, non_blocking)
        return self

    def nbytes(self):
        _warn_typed_storage_removal()
        return self._nbytes()

    # For internal use only, to avoid deprecation warning
    def _nbytes(self):
        return self._untyped_storage.nbytes()

    def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
        _warn_typed_storage_removal()
        if dtype is None:
            legacy_class = self._get_legacy_storage_class()

            if legacy_class is not None:
                return legacy_class.__module__ + '.' + legacy_class.__name__

            return '.'.join([self.__module__, type(self).__name__])

        else:
            return self._untyped_storage.type(dtype, non_blocking)

    def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
        _warn_typed_storage_removal()
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            raise RuntimeError("Cannot create CUDA storage with quantized dtype")
        cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(cuda_storage)

    def element_size(self):
        _warn_typed_storage_removal()
        return self._element_size()

    # For internal use only, to avoid deprecation warning
    def _element_size(self):
        return torch._utils._element_size(self.dtype)

    def get_device(self) -> int:
        _warn_typed_storage_removal()
        return self._untyped_storage.get_device()

    def __str__(self):
        _warn_typed_storage_removal()
        info_str = (
            f'[{torch.typename(self)}(dtype={self.dtype}, '
            f'device={self.device}) of size {len(self)}]')
        if self.device.type == 'meta':
            return '...\n' + info_str
        else:
            data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
            return data_str + '\n' + info_str

    def __repr__(self):
        _warn_typed_storage_removal()
        return str(self)

    def __iter__(self):
        _warn_typed_storage_removal()
        return iter(map(lambda i: self[i], range(self.size())))

    def __copy__(self):
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(copy.copy(self._untyped_storage))

    def __deepcopy__(self, memo):
        _warn_typed_storage_removal()
        return self._deepcopy(memo)

    # For internal use only, to avoid deprecation warning
    def _deepcopy(self, memo):
        return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))

    def __sizeof__(self):
        _warn_typed_storage_removal()
        return super().__sizeof__() + self.nbytes()

    def clone(self):
        """Returns a copy of this storage"""
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.clone())

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        _warn_typed_storage_removal()
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.cpu())

    def pin_memory(self):
        """Copies the storage to pinned memory, if it's not already pinned."""
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.pin_memory())

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        _warn_typed_storage_removal()
        return self._share_memory_()

    # For internal use only, to avoid deprecation warning
    def _share_memory_(self):
        self._untyped_storage.share_memory_()
        return self

    def _new_shared(self, size, *, device=None):
        """Creates a new storage in shared memory with the same data type"""
        if device is None:
            device = 'cpu'
        device = torch.device(device)
        untyped_storage = torch.UntypedStorage._new_shared(size * self._element_size(), device=device)
        return TypedStorage(
            wrap_storage=untyped_storage,
            dtype=self.dtype,
            _internal=True)

    @property
    def _cdata(self):
        return self._untyped_storage._cdata

    @property
    def device(self):
        _warn_typed_storage_removal()
        return self._untyped_storage.device

    def size(self):
        _warn_typed_storage_removal()
        return self._size()

    # For internal use only, to avoid deprecation warning
    def _size(self):
        # NB: don't indirect through __len__, as that requires
        # an int to be returned
        return self._untyped_storage.nbytes() // self._element_size()

    def pickle_storage_type(self):
        _warn_typed_storage_removal()
        return self._pickle_storage_type()

    # For internal use only, to avoid deprecation warning
    def _pickle_storage_type(self):
        try:
            return _dtype_to_storage_type_map()[self.dtype]
        except KeyError as e:
            raise KeyError(f'dtype {self.dtype} is not recognized') from e

    def __reduce__(self):
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def data_ptr(self):
        _warn_typed_storage_removal()
        return self._data_ptr()

    # For internal use only, to avoid deprecation warning
    def _data_ptr(self):
        return self._untyped_storage.data_ptr()

    def resize_(self, size):
        _warn_typed_storage_removal()
        self._resize_(size)

    # For internal use only, to avoid deprecation warning
    def _resize_(self, size):
        self._untyped_storage.resize_(size * self._element_size())

    @classmethod
    def _free_weak_ref(cls, *args, **kwargs):
        return UntypedStorage._free_weak_ref(*args, **kwargs)

    def _weak_ref(self, *args, **kwargs):
        return self._untyped_storage._weak_ref(*args, **kwargs)

    @classmethod
    def from_buffer(cls, *args, **kwargs):
        _warn_typed_storage_removal()
        return cls._from_buffer(*args, **kwargs)

    @classmethod
    def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
        if cls == TypedStorage:
            dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device('cpu' if device is None else device)
            if device.type != 'cpu':
                raise RuntimeError(f'TypedStorage.from_buffer: Not available for device {device.type}')
            untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)

        else:
            if dtype is not None or len(args) == 5:
                raise RuntimeError((
                    "from_buffer: 'dtype' can only be specified in "
                    "UntypedStorage.from_buffer and TypedStorage.from_buffer"))
            if device is not None:
                raise RuntimeError((
                    "from_buffer: 'device' can only be specified in "
                    "UntypedStorage.from_buffer and TypedStorage.from_buffer"))

            dtype = cls._dtype
            untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)

        return TypedStorage(
            wrap_storage=untyped_storage,
            dtype=dtype,
            _internal=True)
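
    # from_buffer() decodes a Python buffer into a storage: called on
    # TypedStorage itself it takes an explicit `dtype` (and is CPU-only), while
    # the legacy subclasses supply their class dtype; in both cases the byte
    # decoding is delegated to torch.UntypedStorage.from_buffer and the result
    # is rewrapped as a TypedStorage.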

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype)._typed_storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage

    def double(self):
        """Casts this storage to double type"""
        _warn_typed_storage_removal()
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type"""
        _warn_typed_storage_removal()
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type"""
        _warn_typed_storage_removal()
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type"""
        _warn_typed_storage_removal()
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type"""
        _warn_typed_storage_removal()
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type"""
        _warn_typed_storage_removal()
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type"""
        _warn_typed_storage_removal()
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type"""
        _warn_typed_storage_removal()
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type"""
        _warn_typed_storage_removal()
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type"""
        _warn_typed_storage_removal()
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type"""
        _warn_typed_storage_removal()
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type"""
        _warn_typed_storage_removal()
        return self._to(torch.cfloat)

    @classmethod
    def from_file(cls, filename, shared, size):
        """
        from_file(filename, shared=False, size=0) -> Storage

        If `shared` is `True`, then memory is shared between all processes.
        All changes are written to the file. If `shared` is `False`, then the changes on
        the storage do not affect the file.

        `size` is the number of elements in the storage. If `shared` is `False`,
        then the file must contain at least `size * sizeof(Type)` bytes
        (`Type` is the type of storage). If `shared` is `True` the file will be
        created if needed.

        Args:
            filename (str): file name to map
            shared (bool): whether to share memory
            size (int): number of elements in the storage
        """
        _warn_typed_storage_removal()
        if cls == TypedStorage:
            raise RuntimeError('from_file can only be called on derived classes')

        untyped_storage: UntypedStorage = UntypedStorage.from_file(
            filename,
            shared,
            size * torch._utils._element_size(cls.dtype))
        storage = cls(wrap_storage=untyped_storage)
        return storage
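
    # Illustrative usage (hypothetical path; only legacy subclasses may call this):
    #
    #   >>> s = torch.FloatStorage.from_file('/tmp/data.bin', shared=True, size=256)
    #
    # maps 256 float32 elements (1024 bytes), creating the file if needed
    # because `shared` is True.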

    @classmethod
    def _expired(cls, *args, **kwargs):
        return UntypedStorage._expired(*args, **kwargs)

    def is_pinned(self):
        _warn_typed_storage_removal()
        return self._untyped_storage.is_pinned()

    def _write_file(self, *args, **kwargs):
        return self._untyped_storage._write_file(*args, **kwargs)

    def _set_from_file(self, *args, **kwargs):
        return self._untyped_storage._set_from_file(*args, **kwargs)

    def _set_cdata(self, *args, **kwargs):
        return self._untyped_storage._set_cdata(*args, **kwargs)

    def _share_cuda_(self, *args, **kwargs):
        return self._untyped_storage._share_cuda_(*args, **kwargs)

    def is_shared(self):
        _warn_typed_storage_removal()
        return self._is_shared()

    # For internal use only, to avoid deprecation warning
    def _is_shared(self):
        return self._untyped_storage.is_shared()

    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs):
        return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)

    def _share_filename_cpu_(self, *args, **kwargs):
        manager_handle, storage_handle, size = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
        return manager_handle, storage_handle, size // self._element_size()

    def _shared_decref(self):
        self._untyped_storage._shared_decref()
        return self

    @classmethod
    def _release_ipc_counter(cls, *args, device=None, **kwargs):
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    def _shared_incref(self, *args, **kwargs):
        return self._untyped_storage._shared_incref(*args, **kwargs)

    def _share_fd_cpu_(self, *args, **kwargs):
        fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
        return fd, size // self._element_size()

    def _get_legacy_storage_class(self):
        if self.dtype not in _dtype_to_storage_type_map():
            return None

        storage_name = _dtype_to_storage_type_map()[self.dtype]

        if self.device.type not in ['cpu', 'cuda']:
            return None

        module = torch if self.device.type == 'cpu' else torch.cuda

        try:
            return getattr(module, storage_name)
        except AttributeError:
            return None


TypedStorage.type.__doc__ = _type.__doc__
TypedStorage.cuda.__doc__ = _cuda.__doc__


class _LegacyStorageMeta(type):
    dtype: torch.dtype

    def __instancecheck__(cls, instance):
        if type(instance) == TypedStorage:
            cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
            return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
        return False
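
# Because of the metaclass above, isinstance() checks against a legacy storage
# class also accept a plain TypedStorage with the matching device and dtype
# (illustrative):
#
#   >>> isinstance(torch.tensor([1.0, 2.0]).storage(), torch.FloatStorage)
#   True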


class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
    @classmethod
    def _new_shared(cls, size):
        """Creates a new storage in shared memory with the same data type"""
        untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size())
        return cls(wrap_storage=untyped_storage)

    @classmethod
    def _release_ipc_counter(cls, *args, **kwargs):
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    @classmethod
    def _new_shared_filename(cls, manager, obj, size):
        bytes_size = size * torch._utils._element_size(cls.dtype)
        return cls(wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))


def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    try:
        return _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError as e:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized') from e
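
# Illustrative: legacy serialization records the dtype as a storage class name,
# which this helper maps back to a torch.dtype.
#
#   >>> _get_dtype_from_pickle_storage_type('FloatStorage')
#   torch.float32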