ndim_array.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. from sympy.core.basic import Basic
  2. from sympy.core.containers import (Dict, Tuple)
  3. from sympy.core.expr import Expr
  4. from sympy.core.kind import Kind, NumberKind, UndefinedKind
  5. from sympy.core.numbers import Integer
  6. from sympy.core.singleton import S
  7. from sympy.core.sympify import sympify
  8. from sympy.external.gmpy import SYMPY_INTS
  9. from sympy.printing.defaults import Printable
  10. import itertools
  11. from collections.abc import Iterable
  12. class ArrayKind(Kind):
  13. """
  14. Kind for N-dimensional array in SymPy.
  15. This kind represents the multidimensional array that algebraic
  16. operations are defined. Basic class for this kind is ``NDimArray``,
  17. but any expression representing the array can have this.
  18. Parameters
  19. ==========
  20. element_kind : Kind
  21. Kind of the element. Default is :obj:NumberKind `<sympy.core.kind.NumberKind>`,
  22. which means that the array contains only numbers.
  23. Examples
  24. ========
  25. Any instance of array class has ``ArrayKind``.
  26. >>> from sympy import NDimArray
  27. >>> NDimArray([1,2,3]).kind
  28. ArrayKind(NumberKind)
  29. Although expressions representing an array may be not instance of
  30. array class, it will have ``ArrayKind`` as well.
  31. >>> from sympy import Integral
  32. >>> from sympy.tensor.array import NDimArray
  33. >>> from sympy.abc import x
  34. >>> intA = Integral(NDimArray([1,2,3]), x)
  35. >>> isinstance(intA, NDimArray)
  36. False
  37. >>> intA.kind
  38. ArrayKind(NumberKind)
  39. Use ``isinstance()`` to check for ``ArrayKind` without specifying
  40. the element kind. Use ``is`` with specifying the element kind.
  41. >>> from sympy.tensor.array import ArrayKind
  42. >>> from sympy.core import NumberKind
  43. >>> boolA = NDimArray([True, False])
  44. >>> isinstance(boolA.kind, ArrayKind)
  45. True
  46. >>> boolA.kind is ArrayKind(NumberKind)
  47. False
  48. See Also
  49. ========
  50. shape : Function to return the shape of objects with ``MatrixKind``.
  51. """
  52. def __new__(cls, element_kind=NumberKind):
  53. obj = super().__new__(cls, element_kind)
  54. obj.element_kind = element_kind
  55. return obj
  56. def __repr__(self):
  57. return "ArrayKind(%s)" % self.element_kind
  58. @classmethod
  59. def _union(cls, kinds) -> 'ArrayKind':
  60. elem_kinds = {e.kind for e in kinds}
  61. if len(elem_kinds) == 1:
  62. elemkind, = elem_kinds
  63. else:
  64. elemkind = UndefinedKind
  65. return ArrayKind(elemkind)
  66. class NDimArray(Printable):
  67. """N-dimensional array.
  68. Examples
  69. ========
  70. Create an N-dim array of zeros:
  71. >>> from sympy import MutableDenseNDimArray
  72. >>> a = MutableDenseNDimArray.zeros(2, 3, 4)
  73. >>> a
  74. [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
  75. Create an N-dim array from a list;
  76. >>> a = MutableDenseNDimArray([[2, 3], [4, 5]])
  77. >>> a
  78. [[2, 3], [4, 5]]
  79. >>> b = MutableDenseNDimArray([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
  80. >>> b
  81. [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
  82. Create an N-dim array from a flat list with dimension shape:
  83. >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3))
  84. >>> a
  85. [[1, 2, 3], [4, 5, 6]]
  86. Create an N-dim array from a matrix:
  87. >>> from sympy import Matrix
  88. >>> a = Matrix([[1,2],[3,4]])
  89. >>> a
  90. Matrix([
  91. [1, 2],
  92. [3, 4]])
  93. >>> b = MutableDenseNDimArray(a)
  94. >>> b
  95. [[1, 2], [3, 4]]
  96. Arithmetic operations on N-dim arrays
  97. >>> a = MutableDenseNDimArray([1, 1, 1, 1], (2, 2))
  98. >>> b = MutableDenseNDimArray([4, 4, 4, 4], (2, 2))
  99. >>> c = a + b
  100. >>> c
  101. [[5, 5], [5, 5]]
  102. >>> a - b
  103. [[-3, -3], [-3, -3]]
  104. """
  105. _diff_wrt = True
  106. is_scalar = False
  107. def __new__(cls, iterable, shape=None, **kwargs):
  108. from sympy.tensor.array import ImmutableDenseNDimArray
  109. return ImmutableDenseNDimArray(iterable, shape, **kwargs)
  110. def __getitem__(self, index):
  111. raise NotImplementedError("A subclass of NDimArray should implement __getitem__")
  112. def _parse_index(self, index):
  113. if isinstance(index, (SYMPY_INTS, Integer)):
  114. if index >= self._loop_size:
  115. raise ValueError("Only a tuple index is accepted")
  116. return index
  117. if self._loop_size == 0:
  118. raise ValueError("Index not valid with an empty array")
  119. if len(index) != self._rank:
  120. raise ValueError('Wrong number of array axes')
  121. real_index = 0
  122. # check if input index can exist in current indexing
  123. for i in range(self._rank):
  124. if (index[i] >= self.shape[i]) or (index[i] < -self.shape[i]):
  125. raise ValueError('Index ' + str(index) + ' out of border')
  126. if index[i] < 0:
  127. real_index += 1
  128. real_index = real_index*self.shape[i] + index[i]
  129. return real_index
  130. def _get_tuple_index(self, integer_index):
  131. index = []
  132. for i, sh in enumerate(reversed(self.shape)):
  133. index.append(integer_index % sh)
  134. integer_index //= sh
  135. index.reverse()
  136. return tuple(index)
  137. def _check_symbolic_index(self, index):
  138. # Check if any index is symbolic:
  139. tuple_index = (index if isinstance(index, tuple) else (index,))
  140. if any((isinstance(i, Expr) and (not i.is_number)) for i in tuple_index):
  141. for i, nth_dim in zip(tuple_index, self.shape):
  142. if ((i < 0) == True) or ((i >= nth_dim) == True):
  143. raise ValueError("index out of range")
  144. from sympy.tensor import Indexed
  145. return Indexed(self, *tuple_index)
  146. return None
  147. def _setter_iterable_check(self, value):
  148. from sympy.matrices.matrices import MatrixBase
  149. if isinstance(value, (Iterable, MatrixBase, NDimArray)):
  150. raise NotImplementedError
  151. @classmethod
  152. def _scan_iterable_shape(cls, iterable):
  153. def f(pointer):
  154. if not isinstance(pointer, Iterable):
  155. return [pointer], ()
  156. if len(pointer) == 0:
  157. return [], (0,)
  158. result = []
  159. elems, shapes = zip(*[f(i) for i in pointer])
  160. if len(set(shapes)) != 1:
  161. raise ValueError("could not determine shape unambiguously")
  162. for i in elems:
  163. result.extend(i)
  164. return result, (len(shapes),)+shapes[0]
  165. return f(iterable)
  166. @classmethod
  167. def _handle_ndarray_creation_inputs(cls, iterable=None, shape=None, **kwargs):
  168. from sympy.matrices.matrices import MatrixBase
  169. from sympy.tensor.array import SparseNDimArray
  170. if shape is None:
  171. if iterable is None:
  172. shape = ()
  173. iterable = ()
  174. # Construction of a sparse array from a sparse array
  175. elif isinstance(iterable, SparseNDimArray):
  176. return iterable._shape, iterable._sparse_array
  177. # Construct N-dim array from another N-dim array:
  178. elif isinstance(iterable, NDimArray):
  179. shape = iterable.shape
  180. # Construct N-dim array from an iterable (numpy arrays included):
  181. elif isinstance(iterable, Iterable):
  182. iterable, shape = cls._scan_iterable_shape(iterable)
  183. # Construct N-dim array from a Matrix:
  184. elif isinstance(iterable, MatrixBase):
  185. shape = iterable.shape
  186. else:
  187. shape = ()
  188. iterable = (iterable,)
  189. if isinstance(iterable, (Dict, dict)) and shape is not None:
  190. new_dict = iterable.copy()
  191. for k, v in new_dict.items():
  192. if isinstance(k, (tuple, Tuple)):
  193. new_key = 0
  194. for i, idx in enumerate(k):
  195. new_key = new_key * shape[i] + idx
  196. iterable[new_key] = iterable[k]
  197. del iterable[k]
  198. if isinstance(shape, (SYMPY_INTS, Integer)):
  199. shape = (shape,)
  200. if not all(isinstance(dim, (SYMPY_INTS, Integer)) for dim in shape):
  201. raise TypeError("Shape should contain integers only.")
  202. return tuple(shape), iterable
  203. def __len__(self):
  204. """Overload common function len(). Returns number of elements in array.
  205. Examples
  206. ========
  207. >>> from sympy import MutableDenseNDimArray
  208. >>> a = MutableDenseNDimArray.zeros(3, 3)
  209. >>> a
  210. [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
  211. >>> len(a)
  212. 9
  213. """
  214. return self._loop_size
  215. @property
  216. def shape(self):
  217. """
  218. Returns array shape (dimension).
  219. Examples
  220. ========
  221. >>> from sympy import MutableDenseNDimArray
  222. >>> a = MutableDenseNDimArray.zeros(3, 3)
  223. >>> a.shape
  224. (3, 3)
  225. """
  226. return self._shape
  227. def rank(self):
  228. """
  229. Returns rank of array.
  230. Examples
  231. ========
  232. >>> from sympy import MutableDenseNDimArray
  233. >>> a = MutableDenseNDimArray.zeros(3,4,5,6,3)
  234. >>> a.rank()
  235. 5
  236. """
  237. return self._rank
  238. def diff(self, *args, **kwargs):
  239. """
  240. Calculate the derivative of each element in the array.
  241. Examples
  242. ========
  243. >>> from sympy import ImmutableDenseNDimArray
  244. >>> from sympy.abc import x, y
  245. >>> M = ImmutableDenseNDimArray([[x, y], [1, x*y]])
  246. >>> M.diff(x)
  247. [[1, 0], [0, y]]
  248. """
  249. from sympy.tensor.array.array_derivatives import ArrayDerivative
  250. kwargs.setdefault('evaluate', True)
  251. return ArrayDerivative(self.as_immutable(), *args, **kwargs)
  252. def _eval_derivative(self, base):
  253. # Types are (base: scalar, self: array)
  254. return self.applyfunc(lambda x: base.diff(x))
  255. def _eval_derivative_n_times(self, s, n):
  256. return Basic._eval_derivative_n_times(self, s, n)
  257. def applyfunc(self, f):
  258. """Apply a function to each element of the N-dim array.
  259. Examples
  260. ========
  261. >>> from sympy import ImmutableDenseNDimArray
  262. >>> m = ImmutableDenseNDimArray([i*2+j for i in range(2) for j in range(2)], (2, 2))
  263. >>> m
  264. [[0, 1], [2, 3]]
  265. >>> m.applyfunc(lambda i: 2*i)
  266. [[0, 2], [4, 6]]
  267. """
  268. from sympy.tensor.array import SparseNDimArray
  269. from sympy.tensor.array.arrayop import Flatten
  270. if isinstance(self, SparseNDimArray) and f(S.Zero) == 0:
  271. return type(self)({k: f(v) for k, v in self._sparse_array.items() if f(v) != 0}, self.shape)
  272. return type(self)(map(f, Flatten(self)), self.shape)
  273. def _sympystr(self, printer):
  274. def f(sh, shape_left, i, j):
  275. if len(shape_left) == 1:
  276. return "["+", ".join([printer._print(self[self._get_tuple_index(e)]) for e in range(i, j)])+"]"
  277. sh //= shape_left[0]
  278. return "[" + ", ".join([f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh) for e in range(shape_left[0])]) + "]" # + "\n"*len(shape_left)
  279. if self.rank() == 0:
  280. return printer._print(self[()])
  281. return f(self._loop_size, self.shape, 0, self._loop_size)
  282. def tolist(self):
  283. """
  284. Converting MutableDenseNDimArray to one-dim list
  285. Examples
  286. ========
  287. >>> from sympy import MutableDenseNDimArray
  288. >>> a = MutableDenseNDimArray([1, 2, 3, 4], (2, 2))
  289. >>> a
  290. [[1, 2], [3, 4]]
  291. >>> b = a.tolist()
  292. >>> b
  293. [[1, 2], [3, 4]]
  294. """
  295. def f(sh, shape_left, i, j):
  296. if len(shape_left) == 1:
  297. return [self[self._get_tuple_index(e)] for e in range(i, j)]
  298. result = []
  299. sh //= shape_left[0]
  300. for e in range(shape_left[0]):
  301. result.append(f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh))
  302. return result
  303. return f(self._loop_size, self.shape, 0, self._loop_size)
  304. def __add__(self, other):
  305. from sympy.tensor.array.arrayop import Flatten
  306. if not isinstance(other, NDimArray):
  307. return NotImplemented
  308. if self.shape != other.shape:
  309. raise ValueError("array shape mismatch")
  310. result_list = [i+j for i,j in zip(Flatten(self), Flatten(other))]
  311. return type(self)(result_list, self.shape)
  312. def __sub__(self, other):
  313. from sympy.tensor.array.arrayop import Flatten
  314. if not isinstance(other, NDimArray):
  315. return NotImplemented
  316. if self.shape != other.shape:
  317. raise ValueError("array shape mismatch")
  318. result_list = [i-j for i,j in zip(Flatten(self), Flatten(other))]
  319. return type(self)(result_list, self.shape)
  320. def __mul__(self, other):
  321. from sympy.matrices.matrices import MatrixBase
  322. from sympy.tensor.array import SparseNDimArray
  323. from sympy.tensor.array.arrayop import Flatten
  324. if isinstance(other, (Iterable, NDimArray, MatrixBase)):
  325. raise ValueError("scalar expected, use tensorproduct(...) for tensorial product")
  326. other = sympify(other)
  327. if isinstance(self, SparseNDimArray):
  328. if other.is_zero:
  329. return type(self)({}, self.shape)
  330. return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape)
  331. result_list = [i*other for i in Flatten(self)]
  332. return type(self)(result_list, self.shape)
  333. def __rmul__(self, other):
  334. from sympy.matrices.matrices import MatrixBase
  335. from sympy.tensor.array import SparseNDimArray
  336. from sympy.tensor.array.arrayop import Flatten
  337. if isinstance(other, (Iterable, NDimArray, MatrixBase)):
  338. raise ValueError("scalar expected, use tensorproduct(...) for tensorial product")
  339. other = sympify(other)
  340. if isinstance(self, SparseNDimArray):
  341. if other.is_zero:
  342. return type(self)({}, self.shape)
  343. return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape)
  344. result_list = [other*i for i in Flatten(self)]
  345. return type(self)(result_list, self.shape)
  346. def __truediv__(self, other):
  347. from sympy.matrices.matrices import MatrixBase
  348. from sympy.tensor.array import SparseNDimArray
  349. from sympy.tensor.array.arrayop import Flatten
  350. if isinstance(other, (Iterable, NDimArray, MatrixBase)):
  351. raise ValueError("scalar expected")
  352. other = sympify(other)
  353. if isinstance(self, SparseNDimArray) and other != S.Zero:
  354. return type(self)({k: v/other for (k, v) in self._sparse_array.items()}, self.shape)
  355. result_list = [i/other for i in Flatten(self)]
  356. return type(self)(result_list, self.shape)
  357. def __rtruediv__(self, other):
  358. raise NotImplementedError('unsupported operation on NDimArray')
  359. def __neg__(self):
  360. from sympy.tensor.array import SparseNDimArray
  361. from sympy.tensor.array.arrayop import Flatten
  362. if isinstance(self, SparseNDimArray):
  363. return type(self)({k: -v for (k, v) in self._sparse_array.items()}, self.shape)
  364. result_list = [-i for i in Flatten(self)]
  365. return type(self)(result_list, self.shape)
  366. def __iter__(self):
  367. def iterator():
  368. if self._shape:
  369. for i in range(self._shape[0]):
  370. yield self[i]
  371. else:
  372. yield self[()]
  373. return iterator()
  374. def __eq__(self, other):
  375. """
  376. NDimArray instances can be compared to each other.
  377. Instances equal if they have same shape and data.
  378. Examples
  379. ========
  380. >>> from sympy import MutableDenseNDimArray
  381. >>> a = MutableDenseNDimArray.zeros(2, 3)
  382. >>> b = MutableDenseNDimArray.zeros(2, 3)
  383. >>> a == b
  384. True
  385. >>> c = a.reshape(3, 2)
  386. >>> c == b
  387. False
  388. >>> a[0,0] = 1
  389. >>> b[0,0] = 2
  390. >>> a == b
  391. False
  392. """
  393. from sympy.tensor.array import SparseNDimArray
  394. if not isinstance(other, NDimArray):
  395. return False
  396. if not self.shape == other.shape:
  397. return False
  398. if isinstance(self, SparseNDimArray) and isinstance(other, SparseNDimArray):
  399. return dict(self._sparse_array) == dict(other._sparse_array)
  400. return list(self) == list(other)
  401. def __ne__(self, other):
  402. return not self == other
  403. def _eval_transpose(self):
  404. if self.rank() != 2:
  405. raise ValueError("array rank not 2")
  406. from .arrayop import permutedims
  407. return permutedims(self, (1, 0))
  408. def transpose(self):
  409. return self._eval_transpose()
  410. def _eval_conjugate(self):
  411. from sympy.tensor.array.arrayop import Flatten
  412. return self.func([i.conjugate() for i in Flatten(self)], self.shape)
  413. def conjugate(self):
  414. return self._eval_conjugate()
  415. def _eval_adjoint(self):
  416. return self.transpose().conjugate()
  417. def adjoint(self):
  418. return self._eval_adjoint()
  419. def _slice_expand(self, s, dim):
  420. if not isinstance(s, slice):
  421. return (s,)
  422. start, stop, step = s.indices(dim)
  423. return [start + i*step for i in range((stop-start)//step)]
  424. def _get_slice_data_for_array_access(self, index):
  425. sl_factors = [self._slice_expand(i, dim) for (i, dim) in zip(index, self.shape)]
  426. eindices = itertools.product(*sl_factors)
  427. return sl_factors, eindices
  428. def _get_slice_data_for_array_assignment(self, index, value):
  429. if not isinstance(value, NDimArray):
  430. value = type(self)(value)
  431. sl_factors, eindices = self._get_slice_data_for_array_access(index)
  432. slice_offsets = [min(i) if isinstance(i, list) else None for i in sl_factors]
  433. # TODO: add checks for dimensions for `value`?
  434. return value, eindices, slice_offsets
  435. @classmethod
  436. def _check_special_bounds(cls, flat_list, shape):
  437. if shape == () and len(flat_list) != 1:
  438. raise ValueError("arrays without shape need one scalar value")
  439. if shape == (0,) and len(flat_list) > 0:
  440. raise ValueError("if array shape is (0,) there cannot be elements")
  441. def _check_index_for_getitem(self, index):
  442. if isinstance(index, (SYMPY_INTS, Integer, slice)):
  443. index = (index,)
  444. if len(index) < self.rank():
  445. index = tuple(index) + \
  446. tuple(slice(None) for i in range(len(index), self.rank()))
  447. if len(index) > self.rank():
  448. raise ValueError('Dimension of index greater than rank of array')
  449. return index
  450. class ImmutableNDimArray(NDimArray, Basic):
  451. _op_priority = 11.0
  452. def __hash__(self):
  453. return Basic.__hash__(self)
  454. def as_immutable(self):
  455. return self
  456. def as_mutable(self):
  457. raise NotImplementedError("abstract method")