dense_ndim_array.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import functools
  2. from typing import List
  3. from sympy.core.basic import Basic
  4. from sympy.core.containers import Tuple
  5. from sympy.core.singleton import S
  6. from sympy.core.sympify import _sympify
  7. from sympy.tensor.array.mutable_ndim_array import MutableNDimArray
  8. from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray, ArrayKind
  9. from sympy.utilities.iterables import flatten
  class DenseNDimArray(NDimArray):
      """
      Shared behaviour of N-dim arrays whose elements are kept in a flat list.

      The elements are stored in ``_array``; an index is converted into a
      position of that list with ``_parse_index``.  Calling this class
      directly hands construction over to ``ImmutableDenseNDimArray``.
      """

      # Flat storage of the array elements.
      _array: List[Basic]

      def __new__(self, *args, **kwargs):
          # Direct instantiation is delegated; all arguments are forwarded
          # unchanged.
          return ImmutableDenseNDimArray(*args, **kwargs)

      @property
      def kind(self) -> ArrayKind:
          """Kind computed by ``ArrayKind._union`` from the stored elements."""
          return ArrayKind._union(self._array)

      def __getitem__(self, index):
          """
          Return the element or the sub-array selected by ``index``.

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray([0, 1, 2, 3], (2, 2))
          >>> a
          [[0, 1], [2, 3]]
          >>> a[0, 0]
          0
          >>> a[1, 1]
          3
          >>> a[0]
          [0, 1]
          >>> a[1]
          [2, 3]

          Symbolic index:

          >>> from sympy.abc import i, j
          >>> a[i, j]
          [[0, 1], [2, 3]][i, j]

          Replace `i` and `j` to get element `(1, 1)`:

          >>> a[i, j].subs({i: 1, j: 1})
          3

          """
          # A non-None result of the symbolic check is the answer itself.
          symbolic = self._check_symbolic_index(index)
          if symbolic is not None:
              return symbolic

          index = self._check_index_for_getitem(index)

          has_slice = isinstance(index, tuple) and any(
              isinstance(part, slice) for part in index)
          if not has_slice:
              # Plain access: one position of the flat storage.
              position = self._parse_index(index)
              return self._array[position]

          # Sliced access: gather the selected elements and build a new array
          # of the same type, keeping only the axes that were given as slices.
          sl_factors, eindices = self._get_slice_data_for_array_access(index)
          picked = [self._array[self._parse_index(entry)] for entry in eindices]
          nshape = [
              len(factor)
              for axis, factor in enumerate(sl_factors)
              if isinstance(index[axis], slice)
          ]
          return type(self)(picked, nshape)

      @classmethod
      def zeros(cls, *shape):
          """Create an array of the given ``shape`` filled with zeros."""
          # Product of all dimensions, starting from S.One (empty shape -> S.One).
          list_length = S.One
          for dim in shape:
              list_length = list_length*dim
          # The zero list is handed to ``_new`` wrapped in a one-element tuple.
          return cls._new(([0]*list_length,), shape)

      def tomatrix(self):
          """
          Convert the array to a Matrix.

          Only arrays of rank 2 can be converted; any other rank raises
          ``ValueError``.

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray([1 for i in range(9)], (3, 3))
          >>> b = a.tomatrix()
          >>> b
          Matrix([
          [1, 1, 1],
          [1, 1, 1],
          [1, 1, 1]])

          """
          from sympy.matrices import Matrix

          if self.rank() != 2:
              raise ValueError('Dimensions must be of size of 2')

          dims = self.shape
          return Matrix(dims[0], dims[1], self._array)

      def reshape(self, *newshape):
          """
          Return an array of the same type holding the same elements under a
          new shape.

          The product of the entries of ``newshape`` has to match the number
          of elements of this array, otherwise ``ValueError`` is raised.

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3))
          >>> a.shape
          (2, 3)
          >>> a
          [[1, 2, 3], [4, 5, 6]]
          >>> b = a.reshape(3, 2)
          >>> b.shape
          (3, 2)
          >>> b
          [[1, 2], [3, 4], [5, 6]]

          """
          new_total_size = functools.reduce(lambda acc, dim: acc*dim, newshape)
          if new_total_size != self._loop_size:
              details = (self._loop_size, str(newshape), new_total_size)
              raise ValueError(
                  'Expecting reshape size to %d but got prod(%s) = %d' % details)

          # there is no `.func` as this class does not subtype `Basic`:
          return type(self)(self._array, newshape)
  class ImmutableDenseNDimArray(DenseNDimArray, ImmutableNDimArray): # type: ignore
      """
      Dense N-dim array that rejects item assignment.

      Instances are created with ``Basic.__new__``, receiving the flattened
      elements and the shape as arguments.
      """

      def __new__(cls, iterable, shape=None, **kwargs):
          return cls._new(iterable, shape, **kwargs)

      @classmethod
      def _new(cls, iterable, shape, **kwargs):
          """Build an instance from ``iterable`` and ``shape``."""
          shape, flat_list = cls._handle_ndarray_creation_inputs(
              iterable, shape, **kwargs)

          # Every dimension is sympified and the shape stored as a Tuple.
          shape = Tuple(*(_sympify(dim) for dim in shape))
          cls._check_special_bounds(flat_list, shape)

          elements = Tuple(*flatten(flat_list))

          # Product of the dimensions; stays 1 for an empty shape.
          loop_size = 1
          for dim in shape:
              loop_size = loop_size*dim

          self = Basic.__new__(cls, elements, shape, **kwargs)
          self._shape = shape
          self._array = list(elements)
          self._rank = len(shape)
          self._loop_size = loop_size
          return self

      def __setitem__(self, index, value):
          """Always raise ``TypeError``: items cannot be assigned."""
          raise TypeError('immutable N-dim array')

      def as_mutable(self):
          """Return a ``MutableDenseNDimArray`` built from this array."""
          return MutableDenseNDimArray(self)

      def _eval_simplify(self, **kwargs):
          """Apply ``simplify`` to the array through ``applyfunc``."""
          from sympy.simplify.simplify import simplify

          return self.applyfunc(simplify)
  class MutableDenseNDimArray(DenseNDimArray, MutableNDimArray):
      """
      Dense N-dim array whose items can be assigned after creation.

      Instances are plain objects created with ``object.__new__``.
      """

      def __new__(cls, iterable=None, shape=None, **kwargs):
          return cls._new(iterable, shape, **kwargs)

      @classmethod
      def _new(cls, iterable, shape, **kwargs):
          """Build an instance from ``iterable`` and ``shape``."""
          shape, flat_list = cls._handle_ndarray_creation_inputs(
              iterable, shape, **kwargs)
          flat_list = flatten(flat_list)

          self = object.__new__(cls)
          self._shape = shape
          self._array = list(flat_list)
          self._rank = len(shape)
          if shape:
              self._loop_size = functools.reduce(lambda acc, dim: acc*dim, shape)
          else:
              # An empty shape falls back to the number of flattened elements.
              self._loop_size = len(flat_list)
          return self

      def __setitem__(self, index, value):
          """
          Assign ``value`` to the item or slice selected by ``index``.

          Examples
          ========

          >>> from sympy import MutableDenseNDimArray
          >>> a = MutableDenseNDimArray.zeros(2, 2)
          >>> a[0,0] = 1
          >>> a[1,1] = 1
          >>> a
          [[1, 0], [0, 1]]

          """
          sliced = isinstance(index, tuple) and any(
              isinstance(part, slice) for part in index)
          if sliced:
              value, eindices, slice_offsets = \
                  self._get_slice_data_for_array_assignment(index, value)
              for target in eindices:
                  # Position inside ``value``: shift by the slice offsets and
                  # skip the axes whose offset is None.
                  source = [pos - offset
                            for pos, offset in zip(target, slice_offsets)
                            if offset is not None]
                  self._array[self._parse_index(target)] = value[source]
              return

          position = self._parse_index(index)
          self._setter_iterable_check(value)
          self._array[position] = _sympify(value)

      def as_immutable(self):
          """Return an ``ImmutableDenseNDimArray`` built from this array."""
          return ImmutableDenseNDimArray(self)

      @property
      def free_symbols(self):
          """Set collecting the ``free_symbols`` of every stored element."""
          collected = set()
          for element in self._array:
              collected.update(element.free_symbols)
          return collected