sparse_ndim_array.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. from sympy.core.basic import Basic
  2. from sympy.core.containers import (Dict, Tuple)
  3. from sympy.core.singleton import S
  4. from sympy.core.sympify import _sympify
  5. from sympy.tensor.array.mutable_ndim_array import MutableNDimArray
  6. from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray
  7. from sympy.utilities.iterables import flatten
  8. import functools
  9. class SparseNDimArray(NDimArray):
  10. def __new__(self, *args, **kwargs):
  11. return ImmutableSparseNDimArray(*args, **kwargs)
  12. def __getitem__(self, index):
  13. """
  14. Get an element from a sparse N-dim array.
  15. Examples
  16. ========
  17. >>> from sympy import MutableSparseNDimArray
  18. >>> a = MutableSparseNDimArray(range(4), (2, 2))
  19. >>> a
  20. [[0, 1], [2, 3]]
  21. >>> a[0, 0]
  22. 0
  23. >>> a[1, 1]
  24. 3
  25. >>> a[0]
  26. [0, 1]
  27. >>> a[1]
  28. [2, 3]
  29. Symbolic indexing:
  30. >>> from sympy.abc import i, j
  31. >>> a[i, j]
  32. [[0, 1], [2, 3]][i, j]
  33. Replace `i` and `j` to get element `(0, 0)`:
  34. >>> a[i, j].subs({i: 0, j: 0})
  35. 0
  36. """
  37. syindex = self._check_symbolic_index(index)
  38. if syindex is not None:
  39. return syindex
  40. index = self._check_index_for_getitem(index)
  41. # `index` is a tuple with one or more slices:
  42. if isinstance(index, tuple) and any(isinstance(i, slice) for i in index):
  43. sl_factors, eindices = self._get_slice_data_for_array_access(index)
  44. array = [self._sparse_array.get(self._parse_index(i), S.Zero) for i in eindices]
  45. nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)]
  46. return type(self)(array, nshape)
  47. else:
  48. index = self._parse_index(index)
  49. return self._sparse_array.get(index, S.Zero)
  50. @classmethod
  51. def zeros(cls, *shape):
  52. """
  53. Return a sparse N-dim array of zeros.
  54. """
  55. return cls({}, shape)
  56. def tomatrix(self):
  57. """
  58. Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error.
  59. Examples
  60. ========
  61. >>> from sympy import MutableSparseNDimArray
  62. >>> a = MutableSparseNDimArray([1 for i in range(9)], (3, 3))
  63. >>> b = a.tomatrix()
  64. >>> b
  65. Matrix([
  66. [1, 1, 1],
  67. [1, 1, 1],
  68. [1, 1, 1]])
  69. """
  70. from sympy.matrices import SparseMatrix
  71. if self.rank() != 2:
  72. raise ValueError('Dimensions must be of size of 2')
  73. mat_sparse = {}
  74. for key, value in self._sparse_array.items():
  75. mat_sparse[self._get_tuple_index(key)] = value
  76. return SparseMatrix(self.shape[0], self.shape[1], mat_sparse)
  77. def reshape(self, *newshape):
  78. new_total_size = functools.reduce(lambda x,y: x*y, newshape)
  79. if new_total_size != self._loop_size:
  80. raise ValueError("Invalid reshape parameters " + newshape)
  81. return type(self)(self._sparse_array, newshape)
  82. class ImmutableSparseNDimArray(SparseNDimArray, ImmutableNDimArray): # type: ignore
  83. def __new__(cls, iterable=None, shape=None, **kwargs):
  84. shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
  85. shape = Tuple(*map(_sympify, shape))
  86. cls._check_special_bounds(flat_list, shape)
  87. loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list)
  88. # Sparse array:
  89. if isinstance(flat_list, (dict, Dict)):
  90. sparse_array = Dict(flat_list)
  91. else:
  92. sparse_array = {}
  93. for i, el in enumerate(flatten(flat_list)):
  94. if el != 0:
  95. sparse_array[i] = _sympify(el)
  96. sparse_array = Dict(sparse_array)
  97. self = Basic.__new__(cls, sparse_array, shape, **kwargs)
  98. self._shape = shape
  99. self._rank = len(shape)
  100. self._loop_size = loop_size
  101. self._sparse_array = sparse_array
  102. return self
  103. def __setitem__(self, index, value):
  104. raise TypeError("immutable N-dim array")
  105. def as_mutable(self):
  106. return MutableSparseNDimArray(self)
  107. class MutableSparseNDimArray(MutableNDimArray, SparseNDimArray):
  108. def __new__(cls, iterable=None, shape=None, **kwargs):
  109. shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs)
  110. self = object.__new__(cls)
  111. self._shape = shape
  112. self._rank = len(shape)
  113. self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list)
  114. # Sparse array:
  115. if isinstance(flat_list, (dict, Dict)):
  116. self._sparse_array = dict(flat_list)
  117. return self
  118. self._sparse_array = {}
  119. for i, el in enumerate(flatten(flat_list)):
  120. if el != 0:
  121. self._sparse_array[i] = _sympify(el)
  122. return self
  123. def __setitem__(self, index, value):
  124. """Allows to set items to MutableDenseNDimArray.
  125. Examples
  126. ========
  127. >>> from sympy import MutableSparseNDimArray
  128. >>> a = MutableSparseNDimArray.zeros(2, 2)
  129. >>> a[0, 0] = 1
  130. >>> a[1, 1] = 1
  131. >>> a
  132. [[1, 0], [0, 1]]
  133. """
  134. if isinstance(index, tuple) and any(isinstance(i, slice) for i in index):
  135. value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value)
  136. for i in eindices:
  137. other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None]
  138. other_value = value[other_i]
  139. complete_index = self._parse_index(i)
  140. if other_value != 0:
  141. self._sparse_array[complete_index] = other_value
  142. elif complete_index in self._sparse_array:
  143. self._sparse_array.pop(complete_index)
  144. else:
  145. index = self._parse_index(index)
  146. value = _sympify(value)
  147. if value == 0 and index in self._sparse_array:
  148. self._sparse_array.pop(index)
  149. else:
  150. self._sparse_array[index] = value
  151. def as_immutable(self):
  152. return ImmutableSparseNDimArray(self)
  153. @property
  154. def free_symbols(self):
  155. return {i for j in self._sparse_array.values() for i in j.free_symbols}