import torch
from typing import Union, Sequence
import inspect
import dis

from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
from functorch._C import dim as _C

_C._patch_tensor_class()
dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
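# Illustrative usage sketch (not part of the original source; names and
# shapes are arbitrary):
#   batch, channel = dims(2)      # create two first-class Dim objects
#   x = torch.rand(4, 3)
#   xb = x[batch, channel]        # bind positional dims 0 and 1 to them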
class DimensionMismatchError(Exception):
    pass


class DimensionBindError(Exception):
    pass


from . import op_properties

# use dict to avoid writing C++ bindings for set
pointwise = {t: True for t in op_properties.pointwise}
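# Assumed usage note (illustrative): the implementation only needs a
# membership test, e.g. something like `if op in pointwise: ...`, so a dict
# with dummy True values stands in for a set.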
use_c = True
if not use_c:
    from . import reference


class _Tensor:
    # fast path around slow wrapping/unwrapping logic for simple queries used
    # by the implementation...
    @property
    def dims(self):
        return tuple(d for d in self._levels if isinstance(d, Dim))

    def dim(self):
        return self.ndim

    if use_c:
        __torch_function__ = classmethod(_C.__torch_function__)
        expand = _C._instancemethod(_C.expand)
    else:
        __torch_function__ = reference.__torch_function__
        expand = reference.expand

    index = _C._instancemethod(_C.index)

    def __repr__(self):
        tensor, levels, ndim = self._tensor, self._levels, self.ndim
        return f'{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}'


TensorLike = (_Tensor, torch.Tensor)


class Dim(_C.Dim, _Tensor):
    # note that _C.Dim comes before _Tensor because we want the Dim API for things like size to take precedence.
    # Tensor defines __format__, but we want to print Dims with special formatting
    __format__ = object.__format__
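# Illustrative sketch of Dim behavior (assumed semantics; names arbitrary):
#   i = dims()
#   x = torch.rand(3, 4)
#   xi = x[i]          # i.size is bound to 3 on first use
#   print(i.size)      # 3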
class Tensor(_Tensor, _C.Tensor):
    if not use_c:
        from_batched = staticmethod(_C.Tensor_from_batched)

    from_positional = staticmethod(_C.Tensor_from_positional)
    sum = _C._instancemethod(_C.Tensor_sum)


def cat(tensors, dim, new_dim):
    n = dims()
    return stack(tensors, n, dim).index([n, dim], new_dim)
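# Illustrative sketch of `cat` (assumed semantics; names arbitrary): a fresh
# dim `n` ranges over the input tensors, and indexing `[n, dim]` into
# `new_dim` flattens that pair into a single concatenated dimension, e.g.
#   row, allrows = dims(2)
#   a, b = torch.rand(3, 5), torch.rand(3, 5)
#   c = cat([a[row], b[row]], row, allrows)   # allrows.size == 2 * 3 == 6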
if use_c:
    _wrap = _C._wrap

    def _def(name, *args, **kwargs):
        orig = getattr(torch.Tensor, name)
        setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))

    t__getitem__ = _C._instancemethod(_C.__getitem__)
    stack = _C.stack
    split = _C._instancemethod(_C.split)
else:
    _wrap, _def = reference._wrap, reference._def
    t__getitem__ = reference.t__getitem__
    stack = reference.stack
    split = reference.split
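# Note (assumption, for orientation): in the C path, `_def(name, ...)` takes
# torch.Tensor's existing method and rebinds a `_wrap`ped version on _Tensor,
# so its `dim` argument also accepts first-class Dim objects; e.g. the calls
# further below behave roughly like:
#   _Tensor.mean = _C._instancemethod(_wrap(torch.Tensor.mean))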
# note: there is no python reference
t__setitem__ = _C._instancemethod(_C.__setitem__)

# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# torch.Tensor.__getitem__ = t__getitem__

_Tensor.__getitem__ = t__getitem__
# torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__

torch.Tensor.split = split
_Tensor.split = split
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim

if use_c:
    _Tensor.order = _C._instancemethod(_C.order)
else:
    _Tensor.order = reference.positional
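# Illustrative round trip with the patched methods (assumed semantics):
#   i, j = dims(2)
#   x = torch.rand(2, 3)
#   xi = x[i, j]            # __getitem__ binds both positional dims
#   y = xi.order(i, j)      # order() lays them out positionally again
#   # y should equal x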
_def('mean')
_def('sum')
_def('all')
_def('amax')
_def('amin')
_def('aminmax')
_def('any')
_def('count_nonzero')
_def('logsumexp')
_def('nanmean')
_def('nansum')
_def('prod')
_def('std', keepdim_offset=2)
_def('var', keepdim_offset=2)
_def('max', single_dim=True)
_def('min', single_dim=True)
_def('argmax', single_dim=True)
_def('argmin', single_dim=True)
_def('kthvalue', single_dim=True)
_def('median', single_dim=True)
_def('nanmedian', single_dim=True)
_def('mode', single_dim=True)
_def('sort', reduce=False)
_def('argsort', reduce=False)
_def('unbind', single_dim=True)
_def('chunk', dim_offset=1, reduce=False)
_def('cummax', single_dim=True, reduce=False)
_def('cummin', single_dim=True, reduce=False)
_def('cumprod', single_dim=True, reduce=False)
_def('cumprod_', single_dim=True, reduce=False)
_def('cumsum', single_dim=True, reduce=False)
_def('cumsum_', single_dim=True, reduce=False)
_def('logcumsumexp', single_dim=True, reduce=False)
_def('renorm', dim_offset=1, single_dim=True, reduce=False)
_def('softmax', single_dim=True, reduce=False)

softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
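# Illustrative usage of the wrapped reductions above (assumed call shapes;
# names arbitrary):
#   batch, feature = dims(2)
#   x = torch.rand(4, 8)[batch, feature]
#   x.sum(batch)            # reduce over a first-class dim
#   softmax(x, feature)     # module-level wrapper over F.softmax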
# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce