- import io
- import torch
- from ._utils import _type, _cuda
- from torch.types import Storage
- from typing import Any, TypeVar, Type, Union, cast
- import copy
- import collections
- from functools import lru_cache
- import warnings
- try:
- import numpy as np
- HAS_NUMPY = True
- except ModuleNotFoundError:
-     np = None # type: ignore[assignment]
-     HAS_NUMPY = False
- T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')
- class _StorageBase:
- _cdata: Any
- is_sparse: bool = False
- is_sparse_csr: bool = False
- device: torch.device
- def __init__(self, *args, **kwargs): ... # noqa: E704
- def __len__(self) -> int: ... # noqa: E704
- def __getitem__(self, idx): ... # noqa: E704
- def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
- def nbytes(self) -> int: ... # noqa: E704
- def size(self) -> int:
- return self.nbytes()
- def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
- def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
- def element_size(self) -> int: ... # noqa: E704
- def get_device(self) -> int: ... # noqa: E704
- def data_ptr(self) -> int: ... # noqa: E704
- # Defined in torch/csrc/generic/StorageSharing.cpp
- def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
- def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
- @classmethod
- def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
- @classmethod
- def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
- @classmethod
- def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
- @classmethod
- def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
- @classmethod
- def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
- @classmethod
- def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
- def _shared_decref(self) -> T: ... # noqa: E704
- def _write_file(self, *args, **kwargs): ... # noqa: E704
- def resize_(self, size: int): ... # noqa: E704
- def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
- def is_pinned(self) -> bool: ... # noqa: E704
- def _set_from_file(self, *args, **kwargs): ... # noqa: E704
- def _set_cdata(self, *args, **kwargs): ... # noqa: E704
- def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
- def is_shared(self) -> bool: ... # noqa: E704
- @classmethod
- def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
- def _shared_incref(self, *args, **kwargs): ... # noqa: E704
- @classmethod
- def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
- @property
- def is_cuda(self): ... # noqa: E704
- @classmethod
- def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704
- @classmethod
- def _expired(cls, *args, **kwargs) -> T: ... # noqa: E704
- def __str__(self):
- info_str = (
- f'[{torch.typename(self)}(device={self.device}) '
- f'of size {len(self)}]')
- if self.device.type == 'meta':
- return '...\n' + info_str
- else:
- data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
- return data_str + '\n' + info_str
- def __repr__(self):
- return str(self)
- def __iter__(self):
- return iter(map(lambda i: self[i], range(self.size())))
- def __copy__(self):
- return self.clone()
- def __deepcopy__(self, memo):
- memo = memo.setdefault('torch', {})
- if self._cdata in memo:
- return memo[self._cdata]
- new_storage = self.clone()
- memo[self._cdata] = new_storage
- return new_storage
- def __reduce__(self):
- b = io.BytesIO()
- torch.save(self, b, _use_new_zipfile_serialization=False)
- return (_load_from_bytes, (b.getvalue(),))
- def __sizeof__(self):
- return super().__sizeof__() + self.size()
- def clone(self):
- """Returns a copy of this storage"""
- return type(self)(self.nbytes(), device=self.device).copy_(self)
- def tolist(self):
- """Returns a list containing the elements of this storage"""
- return list(self)
- def cpu(self):
- """Returns a CPU copy of this storage if it's not already on the CPU"""
- if self.device.type != 'cpu':
- return torch.UntypedStorage(self.size()).copy_(self, False)
- else:
- return self
- def mps(self):
-         """Returns an MPS copy of this storage if it's not already on the MPS device"""
- if self.device.type != 'mps':
- return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
- else:
- return self
- def _to(self, dtype):
- if not isinstance(dtype, torch.dtype):
- raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
- storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype)._typed_storage()
- if storage.data_ptr() == self.data_ptr():
- storage = storage.clone()
- return storage
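-     # A minimal sketch of what the element casts below do (illustrative only):
-     # _to() views the untyped bytes as a uint8 tensor and converts element-wise,
-     # so an 8-byte CPU storage becomes a storage of 8 doubles:
-     #
-     #     >>> u = torch.UntypedStorage(8)
-     #     >>> u.double().size()
-     #     8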
- def double(self):
- """Casts this storage to double type"""
- return self._to(torch.double)
- def float(self):
- """Casts this storage to float type"""
- return self._to(torch.float)
- def half(self):
- """Casts this storage to half type"""
- return self._to(torch.half)
- def long(self):
- """Casts this storage to long type"""
- return self._to(torch.long)
- def int(self):
- """Casts this storage to int type"""
- return self._to(torch.int)
- def short(self):
- """Casts this storage to short type"""
- return self._to(torch.short)
- def char(self):
- """Casts this storage to char type"""
- return self._to(torch.int8)
- def byte(self):
- """Casts this storage to byte type"""
- return self._to(torch.uint8)
- def bool(self):
- """Casts this storage to bool type"""
- return self._to(torch.bool)
- def bfloat16(self):
- """Casts this storage to bfloat16 type"""
- return self._to(torch.bfloat16)
- def complex_double(self):
- """Casts this storage to complex double type"""
- return self._to(torch.cdouble)
- def complex_float(self):
- """Casts this storage to complex float type"""
- return self._to(torch.cfloat)
- def pin_memory(self):
- """Copies the storage to pinned memory, if it's not already pinned."""
- if self.is_cuda:
-             raise TypeError(f"cannot pin '{self.type()}'; only CPU memory can be pinned")
- import torch.cuda
- allocator = torch.cuda.memory._host_allocator() # type: ignore[attr-defined]
- return type(self)(self.size(), allocator=allocator).copy_(self)
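-     # Usage sketch (assumes a CUDA-enabled build, since only then is there a
-     # host allocator to provide pinned memory):
-     #
-     #     >>> s = torch.UntypedStorage(1024)
-     #     >>> s.pin_memory().is_pinned()
-     #     True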
- def share_memory_(self):
- """Moves the storage to shared memory.
- This is a no-op for storages already in shared memory and for CUDA
- storages, which do not need to be moved for sharing across processes.
- Storages in shared memory cannot be resized.
- Returns: self
- """
- from torch.multiprocessing import get_sharing_strategy
- if self.is_cuda:
- pass # CUDA doesn't use POSIX shared memory
- elif get_sharing_strategy() == 'file_system':
- self._share_filename_cpu_()
- else:
- self._share_fd_cpu_()
- return self
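-     # Sharing sketch for a CPU storage (default sharing strategy assumed):
-     #
-     #     >>> s = torch.UntypedStorage(16)
-     #     >>> s.is_shared()
-     #     False
-     #     >>> s.share_memory_().is_shared()
-     #     True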
- @classmethod
- def _new_shared(cls, size, *, device='cpu'):
- """Creates a new storage in shared memory with the same data type"""
- from torch.multiprocessing import get_sharing_strategy
- device = torch.device(device)
- if device.type == 'cuda':
- return cls(size, device=device)
- elif get_sharing_strategy() == 'file_system':
- return cls._new_using_filename_cpu(size)
- else:
- return cls._new_using_fd_cpu(size)
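-     # Allocation sketch for this internal constructor; for UntypedStorage the
-     # size argument is a byte count (shown for illustration only):
-     #
-     #     >>> shm = torch.UntypedStorage._new_shared(64)
-     #     >>> shm.is_shared()
-     #     True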
- def untyped(self):
- return self
- class UntypedStorage(torch._C.StorageBase, _StorageBase):
- def __getitem__(self, *args, **kwargs):
- if self.device.type == 'meta':
- raise NotImplementedError("Not available for 'meta' device type")
- return super().__getitem__(*args, **kwargs)
- @property
- def is_cuda(self):
- return self.device.type == 'cuda'
- def _load_from_bytes(b):
- return torch.load(io.BytesIO(b))
- _StorageBase.type = _type # type: ignore[assignment]
- _StorageBase.cuda = _cuda # type: ignore[assignment]
- @lru_cache(maxsize=None)
- def _dtype_to_storage_type_map():
- # NOTE: We should no longer add dtypes to this map. This map
- # is only used for BC/FC with older PyTorch versions. Going forward,
- # new dtypes of TypedStorage should not translate to a legacy
- # <type>Storage class. Instead, new dtypes of TypedStorage should
- # be serialized as an UntypedStorage paired with a torch.dtype
- return {
- torch.double: 'DoubleStorage',
- torch.float: 'FloatStorage',
- torch.half: 'HalfStorage',
- torch.long: 'LongStorage',
- torch.int: 'IntStorage',
- torch.int16: 'ShortStorage',
- torch.int8: 'CharStorage',
- torch.uint8: 'ByteStorage',
- torch.bool: 'BoolStorage',
- torch.bfloat16: 'BFloat16Storage',
- torch.cdouble: 'ComplexDoubleStorage',
- torch.cfloat: 'ComplexFloatStorage',
- torch.qint8: 'QInt8Storage',
- torch.qint32: 'QInt32Storage',
- torch.quint8: 'QUInt8Storage',
- torch.quint4x2: 'QUInt4x2Storage',
- torch.quint2x4: 'QUInt2x4Storage',
- }
- @lru_cache(maxsize=None)
- def _storage_type_to_dtype_map():
- dtype_map = {
- val: key for key, val in _dtype_to_storage_type_map().items()}
- return dtype_map
- def _get_storage_from_sequence(sequence, dtype, device):
- if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
- interpret_dtypes = {
- torch.quint8: torch.uint8,
- torch.quint4x2: torch.uint8,
- torch.quint2x4: torch.uint8,
- torch.qint32: torch.int32,
- torch.qint8: torch.int8
- }
- tmp_tensor = torch.tensor(
- sequence,
- dtype=interpret_dtypes[dtype],
- device=device)
- else:
- tmp_tensor = torch.tensor(
- sequence,
- dtype=dtype,
- device=device)
- return tmp_tensor._typed_storage()._untyped_storage
- def _isint(x):
- if HAS_NUMPY:
- return isinstance(x, (int, np.integer))
- else:
- return isinstance(x, int)
- def _warn_typed_storage_removal(stacklevel=2):
- message = (
- "TypedStorage is deprecated. It will be removed in the future and "
- "UntypedStorage will be the only storage class. This should only matter "
- "to you if you are using storages directly. To access UntypedStorage "
- "directly, use tensor.untyped_storage() instead of tensor.storage()"
- )
- warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
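- # The migration the warning above points to, sketched with a tensor:
- #
- #     >>> t = torch.tensor([1.0, 2.0])
- #     >>> t.storage()           # deprecated path: returns a TypedStorage and warns
- #     >>> t.untyped_storage()   # preferred: returns the underlying UntypedStorage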
- class TypedStorage:
- is_sparse = False
- dtype: torch.dtype
- @property
- def _dtype(self):
- return self.dtype
- def fill_(self, value):
- _warn_typed_storage_removal()
- self._setitem(slice(0, self._size()), value)
- return self
- def __new__(cls, *args, wrap_storage=None, dtype=None, device=None, _internal=False):
- if not _internal:
- _warn_typed_storage_removal()
- if cls == torch.storage._LegacyStorage:
- raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
- if cls == TypedStorage:
- return super().__new__(cls)
- else:
- arg_error_msg = (
- f'{cls}.__new__ received an invalid combination '
- f'of arguments. Expected one of:\n'
- ' * no arguments\n'
- ' * (int size)\n'
- ' * (Sequence data)\n'
- ' * (*, UntypedStorage wrap_storage)')
- if device is not None:
- raise RuntimeError(
- arg_error_msg +
- "\nKeyword argument 'device' cannot be specified")
- if dtype is not None:
- raise RuntimeError(
- arg_error_msg +
- "\nKeyword argument 'dtype' cannot be specified")
- if wrap_storage is None:
- if len(args) > 1:
- raise RuntimeError(
- arg_error_msg +
- "\nToo many positional arguments")
- if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
- raise TypeError(
- arg_error_msg +
- f"\nArgument type not recognized: {type(args[0])}")
- return TypedStorage(
- *args,
- dtype=cls._dtype,
- device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
- _internal=True)
- else:
- if len(args) != 0:
- raise RuntimeError(
- arg_error_msg +
- "\nNo positional arguments should be given when using "
- "'wrap_storage'")
- if not isinstance(wrap_storage, torch.UntypedStorage):
- raise TypeError(
- arg_error_msg +
- f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
- cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
- if wrap_storage.device.type != cls_device:
- raise RuntimeError(
- arg_error_msg +
- f"\nDevice of 'wrap_storage' must be {cls_device}"
- f", but got {wrap_storage.device.type}")
- return TypedStorage(
- *args,
- wrap_storage=wrap_storage,
- dtype=cls.dtype,
- _internal=True)
- def __init__(self, *args, device=None, dtype=None, wrap_storage=None, _internal=False):
- if not _internal:
- _warn_typed_storage_removal()
- arg_error_msg = (
- 'TypedStorage.__init__ received an invalid combination '
- 'of arguments. Expected one of:\n'
- ' * (*, torch.device device, torch.dtype dtype)\n'
- ' * (int size, *, torch.device device, torch.dtype dtype)\n'
- ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
- ' * (*, UntypedStorage wrap_storage, torch.dtype dtype)')
- if wrap_storage is not None:
- if len(args) != 0:
- raise RuntimeError(
- arg_error_msg +
- "\nNo positional arguments should be given when using "
- "'wrap_storage'")
- if dtype is None:
- raise RuntimeError(
- arg_error_msg +
- "\nArgument 'dtype' must be specified")
- if not isinstance(dtype, torch.dtype):
- raise TypeError(
- arg_error_msg +
- f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")
- if device is not None:
- raise RuntimeError(
- arg_error_msg +
- "\nArgument 'device' should not be specified when 'wrap_storage' is given")
- self.dtype = dtype
- if not isinstance(wrap_storage, torch.UntypedStorage):
- raise TypeError(
- arg_error_msg +
- f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
- self._untyped_storage = wrap_storage
- else:
- self.dtype = torch.get_default_dtype() if dtype is None else dtype
- device = torch.device('cpu' if device is None else device)
- if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
- if device.type == 'cuda':
- raise RuntimeError("Cannot create CUDA storage with quantized dtype")
- if len(args) == 0:
- self._untyped_storage = torch.UntypedStorage(device=device)
- elif len(args) == 1:
- if _isint(args[0]):
- self._untyped_storage = torch.UntypedStorage(int(args[0]) * self._element_size(), device=device)
- elif isinstance(args[0], collections.abc.Sequence):
- self._untyped_storage = _get_storage_from_sequence(args[0], self.dtype, device)
- else:
- raise TypeError(
- arg_error_msg +
- f"\nArgument type not recognized: {type(args[0])}")
- else:
- raise RuntimeError(
- arg_error_msg +
- "\nToo many positional arguments")
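-     # The accepted construction forms, sketched (each also triggers the
-     # deprecation warning above; values are illustrative):
-     #
-     #     >>> torch.TypedStorage(4, dtype=torch.float32, device='cpu')        # by element count
-     #     >>> torch.TypedStorage([1, 2, 3], dtype=torch.int64, device='cpu')  # from a sequence
-     #     >>> torch.TypedStorage(wrap_storage=torch.UntypedStorage(16), dtype=torch.float32)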
- @property
- def is_cuda(self):
- _warn_typed_storage_removal()
- return self._untyped_storage.device.type == 'cuda'
- def untyped(self):
- """Returns the internal :class:`torch.UntypedStorage`"""
- _warn_typed_storage_removal()
- return self._untyped_storage
- def _new_wrapped_storage(self, untyped_storage):
- assert type(untyped_storage) == torch.UntypedStorage
- if type(self) == TypedStorage:
- return TypedStorage(
- wrap_storage=untyped_storage,
- dtype=self.dtype,
- _internal=True)
- else:
- return type(self)(wrap_storage=untyped_storage)
- def __len__(self):
- _warn_typed_storage_removal()
- return self._size()
- def _maybe_wrap_index(self, idx, is_stop=False):
- if idx is None:
- if is_stop:
- return self._size()
- else:
- return 0
- else:
- if type(idx) != int:
- raise TypeError(
- f"can't index a {type(self)} with {type(idx)}")
- if is_stop:
- if (idx > self._size()) or (idx < -self._size()):
- raise IndexError(
- f'index {idx} out of range for storage of size {self.size()}')
- if idx > 0:
- return idx
- else:
- return idx % self._size()
- else:
- if (idx >= self._size()) or (idx < -self._size()):
- raise IndexError(
- f'index {idx} out of range for storage of size {self.size()}')
- return idx % self._size()
- def __setitem__(self, idx, value):
- _warn_typed_storage_removal()
- return self._setitem(idx, value)
- def _setitem(self, idx, value):
- if not isinstance(idx, (int, slice)):
- raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
- if torch.is_storage(value):
- raise RuntimeError(f'cannot set item with value type {type(value)}')
- if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
- interpret_dtypes = {
- torch.quint8: torch.uint8,
- torch.quint4x2: torch.uint8,
- torch.quint2x4: torch.uint8,
- torch.qint32: torch.int32,
- torch.qint8: torch.int8
- }
- tmp_dtype = interpret_dtypes[self.dtype]
- tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self._untyped_storage.device)
- tmp_tensor.set_(TypedStorage(
- wrap_storage=self._untyped_storage,
- dtype=tmp_dtype,
- _internal=True))
- else:
- tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
- tmp_tensor[idx] = value
- def __getitem__(self, idx):
- _warn_typed_storage_removal()
- return self._getitem(idx)
- def _getitem(self, idx):
- if self._untyped_storage.device.type == 'meta':
- raise NotImplementedError("Not available for 'meta' device type")
- # NOTE: Before TypedStorage existed, indexing with a slice used to be
- # possible for <type>Storage objects. However, it would return
- # a storage view, which would be a hassle to implement in TypedStorage,
- # so it was disabled
- if isinstance(idx, slice):
- raise RuntimeError('slices are only supported in UntypedStorage.__getitem__')
- elif not isinstance(idx, int):
- raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
- if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
- interpret_dtypes = {
- torch.quint8: torch.uint8,
- torch.quint4x2: torch.uint8,
- torch.quint2x4: torch.uint8,
- torch.qint32: torch.int32,
- torch.qint8: torch.int8
- }
- return TypedStorage(
- wrap_storage=self._untyped_storage,
- dtype=interpret_dtypes[self.dtype],
- _internal=True)._getitem(idx)
- idx_wrapped = self._maybe_wrap_index(idx)
- tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
- return tmp_tensor[idx_wrapped].item()
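-     # Indexing sketch: integer indices go through a temporary tensor and come
-     # back as Python scalars; slices are rejected as noted above:
-     #
-     #     >>> s = torch.TypedStorage([1, 2, 3], dtype=torch.int64, device='cpu')
-     #     >>> s[1]
-     #     2
-     #     >>> s[0:2]  # raises RuntimeError; slice the untyped storage instead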
- def copy_(self, source: T, non_blocking: bool = None):
- _warn_typed_storage_removal()
- if isinstance(source, TypedStorage):
- self._untyped_storage.copy_(source._untyped_storage, non_blocking)
- else:
- self._untyped_storage.copy_(source, non_blocking)
- return self
- def nbytes(self):
- _warn_typed_storage_removal()
- return self._nbytes()
- # For internal use only, to avoid deprecation warning
- def _nbytes(self):
- return self._untyped_storage.nbytes()
- def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
- _warn_typed_storage_removal()
- if dtype is None:
- legacy_class = self._get_legacy_storage_class()
- if legacy_class is not None:
- return legacy_class.__module__ + '.' + legacy_class.__name__
- return '.'.join([self.__module__, type(self).__name__])
- else:
- return self._untyped_storage.type(dtype, non_blocking)
- def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
- _warn_typed_storage_removal()
- if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
- raise RuntimeError("Cannot create CUDA storage with quantized dtype")
- cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
- return self._new_wrapped_storage(cuda_storage)
- def element_size(self):
- _warn_typed_storage_removal()
- return self._element_size()
- # For internal use only, to avoid deprecation warning
- def _element_size(self):
- return torch._utils._element_size(self.dtype)
- def get_device(self) -> int:
- _warn_typed_storage_removal()
- return self._untyped_storage.get_device()
- def __str__(self):
- _warn_typed_storage_removal()
- info_str = (
- f'[{torch.typename(self)}(dtype={self.dtype}, '
- f'device={self.device}) of size {len(self)}]')
- if self.device.type == 'meta':
- return '...\n' + info_str
- else:
- data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
- return data_str + '\n' + info_str
- def __repr__(self):
- _warn_typed_storage_removal()
- return str(self)
- def __iter__(self):
- _warn_typed_storage_removal()
- return iter(map(lambda i: self[i], range(self.size())))
- def __copy__(self):
- _warn_typed_storage_removal()
- return self._new_wrapped_storage(copy.copy(self._untyped_storage))
- def __deepcopy__(self, memo):
- _warn_typed_storage_removal()
- return self._deepcopy(memo)
- # For internal use only, to avoid deprecation warning
- def _deepcopy(self, memo):
- return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))
- def __sizeof__(self):
- _warn_typed_storage_removal()
- return super().__sizeof__() + self.nbytes()
- def clone(self):
- """Returns a copy of this storage"""
- _warn_typed_storage_removal()
- return self._new_wrapped_storage(self._untyped_storage.clone())
- def tolist(self):
- """Returns a list containing the elements of this storage"""
- _warn_typed_storage_removal()
- return list(self)
- def cpu(self):
- """Returns a CPU copy of this storage if it's not already on the CPU"""
- _warn_typed_storage_removal()
- return self._new_wrapped_storage(self._untyped_storage.cpu())
- def pin_memory(self):
-         """Copies the storage to pinned memory, if it's not already pinned."""
- _warn_typed_storage_removal()
- return self._new_wrapped_storage(self._untyped_storage.pin_memory())
- def share_memory_(self):
- """Moves the storage to shared memory.
- This is a no-op for storages already in shared memory and for CUDA
- storages, which do not need to be moved for sharing across processes.
- Storages in shared memory cannot be resized.
- Returns: self
- """
- _warn_typed_storage_removal()
- return self._share_memory_()
- # For internal use only, to avoid deprecation warning
- def _share_memory_(self):
- self._untyped_storage.share_memory_()
- return self
- def _new_shared(self, size, *, device=None):
- """Creates a new storage in shared memory with the same data type"""
- if device is None:
- device = 'cpu'
- device = torch.device(device)
- untyped_storage = torch.UntypedStorage._new_shared(size * self._element_size(), device=device)
- return TypedStorage(
- wrap_storage=untyped_storage,
- dtype=self.dtype,
- _internal=True)
- @property
- def _cdata(self):
- return self._untyped_storage._cdata
- @property
- def device(self):
- _warn_typed_storage_removal()
- return self._untyped_storage.device
- def size(self):
- _warn_typed_storage_removal()
- return self._size()
- # For internal use only, to avoid deprecation warning
- def _size(self):
- # NB: don't indirect through __len__, as that requires
- # an int to be returned
- return self._untyped_storage.nbytes() // self._element_size()
- def pickle_storage_type(self):
- _warn_typed_storage_removal()
- return self._pickle_storage_type()
- # For internal use only, to avoid deprecation warning
- def _pickle_storage_type(self):
- try:
- return _dtype_to_storage_type_map()[self.dtype]
- except KeyError as e:
- raise KeyError(f'dtype {self.dtype} is not recognized') from e
- def __reduce__(self):
- b = io.BytesIO()
- torch.save(self, b, _use_new_zipfile_serialization=False)
- return (_load_from_bytes, (b.getvalue(),))
- def data_ptr(self):
- _warn_typed_storage_removal()
- return self._data_ptr()
- # For internal use only, to avoid deprecation warning
- def _data_ptr(self):
- return self._untyped_storage.data_ptr()
- def resize_(self, size):
- _warn_typed_storage_removal()
- self._resize_(size)
- # For internal use only, to avoid deprecation warning
- def _resize_(self, size):
- self._untyped_storage.resize_(size * self._element_size())
- @classmethod
- def _free_weak_ref(cls, *args, **kwargs):
- return UntypedStorage._free_weak_ref(*args, **kwargs)
- def _weak_ref(self, *args, **kwargs):
- return self._untyped_storage._weak_ref(*args, **kwargs)
- @classmethod
- def from_buffer(cls, *args, **kwargs):
- _warn_typed_storage_removal()
- return cls._from_buffer(*args, **kwargs)
- @classmethod
- def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
- if cls == TypedStorage:
- dtype = torch.get_default_dtype() if dtype is None else dtype
- device = torch.device('cpu' if device is None else device)
- if device.type != 'cpu':
- raise RuntimeError(f'TypedStorage.from_buffer: Not available for device {device.type}')
- untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
- else:
- if dtype is not None or len(args) == 5:
- raise RuntimeError((
- "from_buffer: 'dtype' can only be specified in "
- "UntypedStorage.from_buffer and TypedStorage.from_buffer"))
- if device is not None:
- raise RuntimeError((
- "from_buffer: 'device' can only be specified in "
- "UntypedStorage.from_buffer and TypedStorage.from_buffer"))
- dtype = cls._dtype
- untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
- return TypedStorage(
- wrap_storage=untyped_storage,
- dtype=dtype,
- _internal=True)
- def _to(self, dtype):
- if not isinstance(dtype, torch.dtype):
- raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
- storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype)._typed_storage()
- if storage.data_ptr() == self.data_ptr():
- storage = storage.clone()
- return storage
- def double(self):
- """Casts this storage to double type"""
- _warn_typed_storage_removal()
- return self._to(torch.double)
- def float(self):
- """Casts this storage to float type"""
- _warn_typed_storage_removal()
- return self._to(torch.float)
- def half(self):
- """Casts this storage to half type"""
- _warn_typed_storage_removal()
- return self._to(torch.half)
- def long(self):
- """Casts this storage to long type"""
- _warn_typed_storage_removal()
- return self._to(torch.long)
- def int(self):
- """Casts this storage to int type"""
- _warn_typed_storage_removal()
- return self._to(torch.int)
- def short(self):
- """Casts this storage to short type"""
- _warn_typed_storage_removal()
- return self._to(torch.short)
- def char(self):
- """Casts this storage to char type"""
- _warn_typed_storage_removal()
- return self._to(torch.int8)
- def byte(self):
- """Casts this storage to byte type"""
- _warn_typed_storage_removal()
- return self._to(torch.uint8)
- def bool(self):
- """Casts this storage to bool type"""
- _warn_typed_storage_removal()
- return self._to(torch.bool)
- def bfloat16(self):
- """Casts this storage to bfloat16 type"""
- _warn_typed_storage_removal()
- return self._to(torch.bfloat16)
- def complex_double(self):
- """Casts this storage to complex double type"""
- _warn_typed_storage_removal()
- return self._to(torch.cdouble)
- def complex_float(self):
- """Casts this storage to complex float type"""
- _warn_typed_storage_removal()
- return self._to(torch.cfloat)
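-     # Cast sketch: each helper above returns a new storage of the target
-     # dtype and leaves the original storage untouched:
-     #
-     #     >>> s = torch.TypedStorage([1, 2, 3], dtype=torch.int64, device='cpu')
-     #     >>> s.float().dtype
-     #     torch.float32
-     #     >>> s.float().tolist()
-     #     [1.0, 2.0, 3.0]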
- @classmethod
- def from_file(cls, filename, shared, size):
- """
- from_file(filename, shared=False, size=0) -> Storage
- If `shared` is `True`, then memory is shared between all processes.
- All changes are written to the file. If `shared` is `False`, then the changes on
- the storage do not affect the file.
- `size` is the number of elements in the storage. If `shared` is `False`,
- then the file must contain at least `size * sizeof(Type)` bytes
-         (`Type` is the type of storage). If `shared` is `True`, the file will be
- created if needed.
- Args:
- filename (str): file name to map
- shared (bool): whether to share memory
- size (int): number of elements in the storage
- """
- _warn_typed_storage_removal()
- if cls == TypedStorage:
- raise RuntimeError('from_file can only be called on derived classes')
- untyped_storage: UntypedStorage = UntypedStorage.from_file(
- filename,
- shared,
- size * torch._utils._element_size(cls.dtype))
- storage = cls(wrap_storage=untyped_storage)
- return storage
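-     # Usage sketch ('data.bin' is a hypothetical path; from_file must be
-     # called on a concrete subclass such as torch.FloatStorage):
-     #
-     #     >>> s = torch.FloatStorage.from_file('data.bin', shared=True, size=10)
-     #     >>> s.size()
-     #     10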
- @classmethod
- def _expired(cls, *args, **kwargs):
- return UntypedStorage._expired(*args, **kwargs)
- def is_pinned(self):
- _warn_typed_storage_removal()
- return self._untyped_storage.is_pinned()
- def _write_file(self, *args, **kwargs):
- return self._untyped_storage._write_file(*args, **kwargs)
- def _set_from_file(self, *args, **kwargs):
- return self._untyped_storage._set_from_file(*args, **kwargs)
- def _set_cdata(self, *args, **kwargs):
- return self._untyped_storage._set_cdata(*args, **kwargs)
- def _share_cuda_(self, *args, **kwargs):
- return self._untyped_storage._share_cuda_(*args, **kwargs)
- def is_shared(self):
- _warn_typed_storage_removal()
- return self._is_shared()
- # For internal use only, to avoid deprecation warning
- def _is_shared(self):
- return self._untyped_storage.is_shared()
- @classmethod
- def _new_shared_cuda(cls, *args, **kwargs):
- return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
- def _share_filename_cpu_(self, *args, **kwargs):
- manager_handle, storage_handle, size = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
- return manager_handle, storage_handle, size // self._element_size()
- def _shared_decref(self):
- self._untyped_storage._shared_decref()
- return self
- @classmethod
- def _release_ipc_counter(cls, *args, device=None, **kwargs):
- return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
- def _shared_incref(self, *args, **kwargs):
- return self._untyped_storage._shared_incref(*args, **kwargs)
- def _share_fd_cpu_(self, *args, **kwargs):
- fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
- return fd, size // self._element_size()
- def _get_legacy_storage_class(self):
- if self.dtype not in _dtype_to_storage_type_map():
- return None
- storage_name = _dtype_to_storage_type_map()[self.dtype]
- if self.device.type not in ['cpu', 'cuda']:
- return None
- module = torch if self.device.type == 'cpu' else torch.cuda
- try:
- return getattr(module, storage_name)
- except AttributeError:
- return None
- TypedStorage.type.__doc__ = _type.__doc__
- TypedStorage.cuda.__doc__ = _cuda.__doc__
- class _LegacyStorageMeta(type):
- dtype: torch.dtype
- def __instancecheck__(cls, instance):
- if type(instance) == TypedStorage:
- cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
- return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
- return False
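- # What the metaclass above buys, sketched: a plain TypedStorage with a matching
- # dtype and device still passes isinstance checks against the legacy classes:
- #
- #     >>> s = torch.TypedStorage(4, dtype=torch.float32, device='cpu')
- #     >>> isinstance(s, torch.FloatStorage)
- #     True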
- class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
- @classmethod
- def _new_shared(cls, size):
- """Creates a new storage in shared memory with the same data type"""
- untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size())
- return cls(wrap_storage=untyped_storage)
- @classmethod
- def _release_ipc_counter(cls, *args, **kwargs):
- return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
- @classmethod
- def _new_shared_filename(cls, manager, obj, size):
- bytes_size = size * torch._utils._element_size(cls.dtype)
- return cls(wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
- def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
- try:
- return _storage_type_to_dtype_map()[pickle_storage_type]
- except KeyError as e:
- raise KeyError(
- f'pickle storage type "{pickle_storage_type}" is not recognized') from e
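- # Round-trip sketch between dtypes and the legacy pickle storage-type names:
- #
- #     >>> _dtype_to_storage_type_map()[torch.float32]
- #     'FloatStorage'
- #     >>> _get_dtype_from_pickle_storage_type('FloatStorage')
- #     torch.float32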
|