# __init__.py (17 KB)
# The Tensor classes are added to this module by python_tensor.cpp
import functools
from typing import Optional, Tuple, List, Union
# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
# Two sets of aliases are defined for the annotations used in this module:
# one seen only by static type checkers, one used at runtime.
if TYPE_CHECKING:
    # Static type checkers get the precise dtype type and accept an int,
    # a tuple of ints or a list of ints for ``dim``.
    from torch.types import _dtype as DType
    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]
  15. __all__ = [
  16. 'addmm',
  17. 'check_sparse_tensor_invariants',
  18. 'mm',
  19. 'sum',
  20. 'softmax',
  21. 'log_softmax',
  22. ]
  23. addmm = _add_docstr(_sparse._sparse_addmm, r"""
  24. sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
  25. This function does exact same thing as :func:`torch.addmm` in the forward,
  26. except that it supports backward for sparse COO matrix :attr:`mat1`.
  27. When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
  28. When inputs are COO tensors, this function also supports backward for both inputs.
  29. Supports both CSR and COO storage formats.
  30. .. note::
  31. This function doesn't support computing derivaties with respect to CSR matrices.
  32. Args:
  33. mat (Tensor): a dense matrix to be added
  34. mat1 (Tensor): a sparse matrix to be multiplied
  35. mat2 (Tensor): a dense matrix to be multiplied
  36. beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
  37. alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
  38. """)
  39. mm = _add_docstr(_sparse._sparse_mm, r"""
  40. Performs a matrix multiplication of the sparse matrix :attr:`mat1`
  41. and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
  42. :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
  43. :math:`(n \times p)` tensor.
  44. When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
  45. When inputs are COO tensors, this function also supports backward for both inputs.
  46. Supports both CSR and COO storage formats.
  47. .. note::
  48. This function doesn't support computing derivaties with respect to CSR matrices.
  49. This function also additionally accepts an optional :attr:`reduce` argument that allows
  50. specification of an optional reduction operation, mathematically performs the following operation:
  51. .. math::
  52. z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
  53. where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
  54. CSR storage format on CPU device.
  55. Args:
  56. mat1 (Tensor): the first sparse matrix to be multiplied
  57. mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
  58. reduce (str, optional): the reduction operation to apply for non-unique indices
  59. (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
  60. Shape:
  61. The format of the output tensor of this function follows:
  62. - sparse x sparse -> sparse
  63. - sparse x dense -> dense
  64. Example::
  65. >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
  66. >>> a
  67. tensor(indices=tensor([[0, 0, 1],
  68. [0, 2, 1]]),
  69. values=tensor([1., 2., 3.]),
  70. size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
  71. >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
  72. >>> b
  73. tensor([[0., 1.],
  74. [2., 0.],
  75. [0., 0.]], requires_grad=True)
  76. >>> y = torch.sparse.mm(a, b)
  77. >>> y
  78. tensor([[0., 1.],
  79. [6., 0.]], grad_fn=<SparseAddmmBackward0>)
  80. >>> y.sum().backward()
  81. >>> a.grad
  82. tensor(indices=tensor([[0, 0, 1],
  83. [0, 2, 1]]),
  84. values=tensor([1., 0., 2.]),
  85. size=(2, 3), nnz=3, layout=torch.sparse_coo)
  86. >>> c = a.detach().to_sparse_csr()
  87. >>> c
  88. tensor(crow_indices=tensor([0, 2, 3]),
  89. col_indices=tensor([0, 2, 1]),
  90. values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
  91. layout=torch.sparse_csr)
  92. >>> y1 = torch.sparse.mm(c, b, 'sum')
  93. >>> y1
  94. tensor([[0., 1.],
  95. [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
  96. >>> y2 = torch.sparse.mm(c, b, 'max')
  97. >>> y2
  98. tensor([[0., 1.],
  99. [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
  100. """)
# Public name for the ``sparse_sampled_addmm`` binding from ``torch._C._sparse``.
# ``_add_docstr`` receives the binding and the documentation text below, and
# its return value is what this module exposes as ``sampled_addmm``.
sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor

Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.

Mathematically this performs the following operation:

.. math::

    \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}

where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
and :attr:`beta` are the scaling factors.
:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.

.. note::
    :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.

Args:
    input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
        the sampled matrix multiplication
    mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
    mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied

Keyword args:
    beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.

Examples::

    >>> input = torch.eye(3, device='cuda').to_sparse_csr()
    >>> mat1 = torch.randn(3, 5, device='cuda')
    >>> mat2 = torch.randn(5, 3, device='cuda')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
    >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
    tensor([[ 0.2847, 0.0000, 0.0000],
            [ 0.0000, -0.7805, 0.0000],
            [ 0.0000, 0.0000, -0.1900]], device='cuda:0')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
""")
  141. def sum(input: Tensor, dim: DimOrDims = None,
  142. dtype: Optional[DType] = None) -> Tensor:
  143. r"""
  144. Returns the sum of each row of the sparse tensor :attr:`input` in the given
  145. dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
  146. reduce over all of them. When sum over all ``sparse_dim``, this method
  147. returns a dense tensor instead of a sparse tensor.
  148. All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output
  149. tensor having :attr:`dim` fewer dimensions than :attr:`input`.
  150. During backward, only gradients at ``nnz`` locations of :attr:`input`
  151. will propagate back. Note that the gradients of :attr:`input` is coalesced.
  152. Args:
  153. input (Tensor): the input sparse tensor
  154. dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
  155. over all dims.
  156. dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
  157. Default: dtype of :attr:`input`.
  158. Example::
  159. >>> nnz = 3
  160. >>> dims = [5, 5, 2, 3]
  161. >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
  162. torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
  163. >>> V = torch.randn(nnz, dims[2], dims[3])
  164. >>> size = torch.Size(dims)
  165. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  166. >>> S = torch.sparse_coo_tensor(I, V, size)
  167. >>> S
  168. tensor(indices=tensor([[2, 0, 3],
  169. [2, 4, 1]]),
  170. values=tensor([[[-0.6438, -1.6467, 1.4004],
  171. [ 0.3411, 0.0918, -0.2312]],
  172. [[ 0.5348, 0.0634, -2.0494],
  173. [-0.7125, -1.0646, 2.1844]],
  174. [[ 0.1276, 0.1874, -0.6334],
  175. [-1.9682, -0.5340, 0.7483]]]),
  176. size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
  177. # when sum over only part of sparse_dims, return a sparse tensor
  178. >>> torch.sparse.sum(S, [1, 3])
  179. tensor(indices=tensor([[0, 2, 3]]),
  180. values=tensor([[-1.4512, 0.4073],
  181. [-0.8901, 0.2017],
  182. [-0.3183, -1.7539]]),
  183. size=(5, 2), nnz=3, layout=torch.sparse_coo)
  184. # when sum over all sparse dim, return a dense tensor
  185. # with summed dims squeezed
  186. >>> torch.sparse.sum(S, [0, 1, 3])
  187. tensor([-2.6596, -1.1450])
  188. """
  189. if dtype is None:
  190. if dim is not None:
  191. return torch._sparse_sum(input, dim)
  192. else:
  193. return torch._sparse_sum(input)
  194. else:
  195. if dim is not None:
  196. return torch._sparse_sum(input, dim, dtype=dtype)
  197. else:
  198. return torch._sparse_sum(input, dtype=dtype)
  199. softmax = _add_docstr(_sparse._sparse_softmax, r"""
  200. sparse.softmax(input, dim, *, dtype=None) -> Tensor
  201. Applies a softmax function.
  202. Softmax is defined as:
  203. :math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
  204. where :math:`i, j` run over sparse tensor indices and unspecified
  205. entries are ignores. This is equivalent to defining unspecified
  206. entries as negative infinity so that :math:`exp(x_k) = 0` when the
  207. entry with index :math:`k` has not specified.
  208. It is applied to all slices along `dim`, and will re-scale them so
  209. that the elements lie in the range `[0, 1]` and sum to 1.
  210. Args:
  211. input (Tensor): input
  212. dim (int): A dimension along which softmax will be computed.
  213. dtype (:class:`torch.dtype`, optional): the desired data type
  214. of returned tensor. If specified, the input tensor is
  215. casted to :attr:`dtype` before the operation is
  216. performed. This is useful for preventing data type
  217. overflows. Default: None
  218. """)
  219. log_softmax = _add_docstr(_sparse._sparse_log_softmax, r"""
  220. sparse.log_softmax(input, dim, *, dtype=None) -> Tensor
  221. Applies a softmax function followed by logarithm.
  222. See :class:`~torch.sparse.softmax` for more details.
  223. Args:
  224. input (Tensor): input
  225. dim (int): A dimension along which softmax will be computed.
  226. dtype (:class:`torch.dtype`, optional): the desired data type
  227. of returned tensor. If specified, the input tensor is
  228. casted to :attr:`dtype` before the operation is
  229. performed. This is useful for preventing data type
  230. overflows. Default: None
  231. """)
  232. spdiags = _add_docstr(
  233. _sparse._spdiags,
  234. r"""
  235. sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor
  236. Creates a sparse 2D tensor by placing the values from rows of
  237. :attr:`diagonals` along specified diagonals of the output
  238. The :attr:`offsets` tensor controls which diagonals are set.
  239. - If :attr:`offsets[i]` = 0, it is the main diagonal
  240. - If :attr:`offsets[i]` < 0, it is below the main diagonal
  241. - If :attr:`offsets[i]` > 0, it is above the main diagonal
  242. The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
  243. and an offset may not be repeated.
  244. Args:
  245. diagonals (Tensor): Matrix storing diagonals row-wise
  246. offsets (Tensor): The diagonals to be set, stored as a vector
  247. shape (2-tuple of ints): The desired shape of the result
  248. Keyword args:
  249. layout (:class:`torch.layout`, optional): The desired layout of the
  250. returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
  251. are supported. Default: ``torch.sparse_coo``
  252. Examples:
  253. Set the main and first two lower diagonals of a matrix::
  254. >>> diags = torch.arange(9).reshape(3, 3)
  255. >>> diags
  256. tensor([[0, 1, 2],
  257. [3, 4, 5],
  258. [6, 7, 8]])
  259. >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
  260. >>> s
  261. tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
  262. [0, 1, 2, 0, 1, 0]]),
  263. values=tensor([0, 1, 2, 3, 4, 6]),
  264. size=(3, 3), nnz=6, layout=torch.sparse_coo)
  265. >>> s.to_dense()
  266. tensor([[0, 0, 0],
  267. [3, 1, 0],
  268. [6, 4, 2]])
  269. Change the output layout::
  270. >>> diags = torch.arange(9).reshape(3, 3)
  271. >>> diags
  272. tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8])
  273. >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
  274. >>> s
  275. tensor(crow_indices=tensor([0, 1, 3, 6]),
  276. col_indices=tensor([0, 0, 1, 0, 1, 2]),
  277. values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
  278. layout=torch.sparse_csr)
  279. >>> s.to_dense()
  280. tensor([[0, 0, 0],
  281. [3, 1, 0],
  282. [6, 4, 2]])
  283. Set partial diagonals of a large output::
  284. >>> diags = torch.tensor([[1, 2], [3, 4]])
  285. >>> offsets = torch.tensor([0, -1])
  286. >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
  287. tensor([[1, 0, 0, 0, 0],
  288. [3, 2, 0, 0, 0],
  289. [0, 4, 0, 0, 0],
  290. [0, 0, 0, 0, 0],
  291. [0, 0, 0, 0, 0]])
  292. .. note::
  293. When setting the values along a given diagonal the index into the diagonal
  294. and the index into the row of :attr:`diagonals` is taken as the
  295. column index in the output. This has the effect that when setting a diagonal
  296. with a positive offset `k` the first value along that diagonal will be
  297. the value in position `k` of the row of :attr:`diagonals`
  298. Specifying a positive offset::
  299. >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
  300. >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
  301. tensor([[1, 2, 3, 0, 0],
  302. [0, 2, 3, 0, 0],
  303. [0, 0, 3, 0, 0],
  304. [0, 0, 0, 0, 0],
  305. [0, 0, 0, 0, 0]])
  306. """)
  307. class check_sparse_tensor_invariants:
  308. """A tool to control checking sparse tensor invariants.
  309. The following options exists to manage sparsr tensor invariants
  310. checking in sparse tensor construction:
  311. 1. Using a context manager:
  312. .. code:: python
  313. with torch.sparse.check_sparse_tensor_invariants():
  314. run_my_model()
  315. 2. Using a procedural approach:
  316. .. code:: python
  317. prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
  318. torch.sparse.check_sparse_tensor_invariants.enable()
  319. run_my_model()
  320. if not prev_checks_enabled:
  321. torch.sparse.check_sparse_tensor_invariants.disable()
  322. 3. Using function decoration:
  323. .. code:: python
  324. @torch.sparse.check_sparse_tensor_invariants()
  325. def run_my_model():
  326. ...
  327. run_my_model()
  328. 4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
  329. For example:
  330. >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
  331. Traceback (most recent call last):
  332. File "<stdin>", line 1, in <module>
  333. RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
  334. """
  335. @staticmethod
  336. def is_enabled():
  337. r"""Returns True if the sparse tensor invariants checking is enabled.
  338. .. note::
  339. Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
  340. :func:`torch.sparse.check_sparse_tensor_invariants.disable` to
  341. manage the state of the sparse tensor invariants checks.
  342. """
  343. return torch._C._check_sparse_tensor_invariants()
  344. @staticmethod
  345. def enable():
  346. r"""Enable sparse tensor invariants checking in sparse tensor constructors.
  347. .. note::
  348. By default, the sparse tensor invariants checks are disabled. Use
  349. :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
  350. retrieve the current state of sparse tensor invariants checking.
  351. .. note::
  352. The sparse tensor invariants check flag is effective to all sparse
  353. tensor constructors, both in Python and ATen.
  354. The flag can be locally overridden by the ``check_invariants``
  355. optional argument of the sparse tensor constructor functions.
  356. """
  357. torch._C._set_check_sparse_tensor_invariants(True)
  358. @staticmethod
  359. def disable():
  360. r"""Disable sparse tensor invariants checking in sparse tensor constructors.
  361. See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
  362. """
  363. torch._C._set_check_sparse_tensor_invariants(False)
  364. # context manager support
  365. def __init__(self, enable=True):
  366. self.state = enable
  367. self.saved_state = self.is_enabled()
  368. def __enter__(self):
  369. torch._C._set_check_sparse_tensor_invariants(self.state)
  370. def __exit__(self, type, value, traceback):
  371. torch._C._set_check_sparse_tensor_invariants(self.saved_state)
  372. # decorator support
  373. def __call__(self, mth):
  374. def test_mth(*args, **kwargs):
  375. with type(self)(self.state):
  376. return mth(*args, **kwargs)
  377. return test_mth