array_comprehension.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import functools, itertools
  2. from sympy.core.sympify import _sympify, sympify
  3. from sympy.core.expr import Expr
  4. from sympy.core import Basic, Tuple
  5. from sympy.tensor.array import ImmutableDenseNDimArray
  6. from sympy.core.symbol import Symbol
  7. from sympy.core.numbers import Integer
  8. class ArrayComprehension(Basic):
  9. """
  10. Generate a list comprehension.
  11. Explanation
  12. ===========
  13. If there is a symbolic dimension, for example, say [i for i in range(1, N)] where
  14. N is a Symbol, then the expression will not be expanded to an array. Otherwise,
  15. calling the doit() function will launch the expansion.
  16. Examples
  17. ========
  18. >>> from sympy.tensor.array import ArrayComprehension
  19. >>> from sympy import symbols
  20. >>> i, j, k = symbols('i j k')
  21. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  22. >>> a
  23. ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  24. >>> a.doit()
  25. [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
  26. >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
  27. >>> b.doit()
  28. ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
  29. """
  30. def __new__(cls, function, *symbols, **assumptions):
  31. if any(len(l) != 3 or None for l in symbols):
  32. raise ValueError('ArrayComprehension requires values lower and upper bound'
  33. ' for the expression')
  34. arglist = [sympify(function)]
  35. arglist.extend(cls._check_limits_validity(function, symbols))
  36. obj = Basic.__new__(cls, *arglist, **assumptions)
  37. obj._limits = obj._args[1:]
  38. obj._shape = cls._calculate_shape_from_limits(obj._limits)
  39. obj._rank = len(obj._shape)
  40. obj._loop_size = cls._calculate_loop_size(obj._shape)
  41. return obj
  42. @property
  43. def function(self):
  44. """The function applied across limits.
  45. Examples
  46. ========
  47. >>> from sympy.tensor.array import ArrayComprehension
  48. >>> from sympy import symbols
  49. >>> i, j = symbols('i j')
  50. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  51. >>> a.function
  52. 10*i + j
  53. """
  54. return self._args[0]
  55. @property
  56. def limits(self):
  57. """
  58. The list of limits that will be applied while expanding the array.
  59. Examples
  60. ========
  61. >>> from sympy.tensor.array import ArrayComprehension
  62. >>> from sympy import symbols
  63. >>> i, j = symbols('i j')
  64. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  65. >>> a.limits
  66. ((i, 1, 4), (j, 1, 3))
  67. """
  68. return self._limits
  69. @property
  70. def free_symbols(self):
  71. """
  72. The set of the free_symbols in the array.
  73. Variables appeared in the bounds are supposed to be excluded
  74. from the free symbol set.
  75. Examples
  76. ========
  77. >>> from sympy.tensor.array import ArrayComprehension
  78. >>> from sympy import symbols
  79. >>> i, j, k = symbols('i j k')
  80. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  81. >>> a.free_symbols
  82. set()
  83. >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
  84. >>> b.free_symbols
  85. {k}
  86. """
  87. expr_free_sym = self.function.free_symbols
  88. for var, inf, sup in self._limits:
  89. expr_free_sym.discard(var)
  90. curr_free_syms = inf.free_symbols.union(sup.free_symbols)
  91. expr_free_sym = expr_free_sym.union(curr_free_syms)
  92. return expr_free_sym
  93. @property
  94. def variables(self):
  95. """The tuples of the variables in the limits.
  96. Examples
  97. ========
  98. >>> from sympy.tensor.array import ArrayComprehension
  99. >>> from sympy import symbols
  100. >>> i, j, k = symbols('i j k')
  101. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  102. >>> a.variables
  103. [i, j]
  104. """
  105. return [l[0] for l in self._limits]
  106. @property
  107. def bound_symbols(self):
  108. """The list of dummy variables.
  109. Note
  110. ====
  111. Note that all variables are dummy variables since a limit without
  112. lower bound or upper bound is not accepted.
  113. """
  114. return [l[0] for l in self._limits if len(l) != 1]
  115. @property
  116. def shape(self):
  117. """
  118. The shape of the expanded array, which may have symbols.
  119. Note
  120. ====
  121. Both the lower and the upper bounds are included while
  122. calculating the shape.
  123. Examples
  124. ========
  125. >>> from sympy.tensor.array import ArrayComprehension
  126. >>> from sympy import symbols
  127. >>> i, j, k = symbols('i j k')
  128. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  129. >>> a.shape
  130. (4, 3)
  131. >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
  132. >>> b.shape
  133. (4, k + 3)
  134. """
  135. return self._shape
  136. @property
  137. def is_shape_numeric(self):
  138. """
  139. Test if the array is shape-numeric which means there is no symbolic
  140. dimension.
  141. Examples
  142. ========
  143. >>> from sympy.tensor.array import ArrayComprehension
  144. >>> from sympy import symbols
  145. >>> i, j, k = symbols('i j k')
  146. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  147. >>> a.is_shape_numeric
  148. True
  149. >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
  150. >>> b.is_shape_numeric
  151. False
  152. """
  153. for _, inf, sup in self._limits:
  154. if Basic(inf, sup).atoms(Symbol):
  155. return False
  156. return True
  157. def rank(self):
  158. """The rank of the expanded array.
  159. Examples
  160. ========
  161. >>> from sympy.tensor.array import ArrayComprehension
  162. >>> from sympy import symbols
  163. >>> i, j, k = symbols('i j k')
  164. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  165. >>> a.rank()
  166. 2
  167. """
  168. return self._rank
  169. def __len__(self):
  170. """
  171. The length of the expanded array which means the number
  172. of elements in the array.
  173. Raises
  174. ======
  175. ValueError : When the length of the array is symbolic
  176. Examples
  177. ========
  178. >>> from sympy.tensor.array import ArrayComprehension
  179. >>> from sympy import symbols
  180. >>> i, j = symbols('i j')
  181. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  182. >>> len(a)
  183. 12
  184. """
  185. if self._loop_size.free_symbols:
  186. raise ValueError('Symbolic length is not supported')
  187. return self._loop_size
  188. @classmethod
  189. def _check_limits_validity(cls, function, limits):
  190. #limits = sympify(limits)
  191. new_limits = []
  192. for var, inf, sup in limits:
  193. var = _sympify(var)
  194. inf = _sympify(inf)
  195. #since this is stored as an argument, it should be
  196. #a Tuple
  197. if isinstance(sup, list):
  198. sup = Tuple(*sup)
  199. else:
  200. sup = _sympify(sup)
  201. new_limits.append(Tuple(var, inf, sup))
  202. if any((not isinstance(i, Expr)) or i.atoms(Symbol, Integer) != i.atoms()
  203. for i in [inf, sup]):
  204. raise TypeError('Bounds should be an Expression(combination of Integer and Symbol)')
  205. if (inf > sup) == True:
  206. raise ValueError('Lower bound should be inferior to upper bound')
  207. if var in inf.free_symbols or var in sup.free_symbols:
  208. raise ValueError('Variable should not be part of its bounds')
  209. return new_limits
  210. @classmethod
  211. def _calculate_shape_from_limits(cls, limits):
  212. return tuple([sup - inf + 1 for _, inf, sup in limits])
  213. @classmethod
  214. def _calculate_loop_size(cls, shape):
  215. if not shape:
  216. return 0
  217. loop_size = 1
  218. for l in shape:
  219. loop_size = loop_size * l
  220. return loop_size
  221. def doit(self, **hints):
  222. if not self.is_shape_numeric:
  223. return self
  224. return self._expand_array()
  225. def _expand_array(self):
  226. res = []
  227. for values in itertools.product(*[range(inf, sup+1)
  228. for var, inf, sup
  229. in self._limits]):
  230. res.append(self._get_element(values))
  231. return ImmutableDenseNDimArray(res, self.shape)
  232. def _get_element(self, values):
  233. temp = self.function
  234. for var, val in zip(self.variables, values):
  235. temp = temp.subs(var, val)
  236. return temp
  237. def tolist(self):
  238. """Transform the expanded array to a list.
  239. Raises
  240. ======
  241. ValueError : When there is a symbolic dimension
  242. Examples
  243. ========
  244. >>> from sympy.tensor.array import ArrayComprehension
  245. >>> from sympy import symbols
  246. >>> i, j = symbols('i j')
  247. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  248. >>> a.tolist()
  249. [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
  250. """
  251. if self.is_shape_numeric:
  252. return self._expand_array().tolist()
  253. raise ValueError("A symbolic array cannot be expanded to a list")
  254. def tomatrix(self):
  255. """Transform the expanded array to a matrix.
  256. Raises
  257. ======
  258. ValueError : When there is a symbolic dimension
  259. ValueError : When the rank of the expanded array is not equal to 2
  260. Examples
  261. ========
  262. >>> from sympy.tensor.array import ArrayComprehension
  263. >>> from sympy import symbols
  264. >>> i, j = symbols('i j')
  265. >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
  266. >>> a.tomatrix()
  267. Matrix([
  268. [11, 12, 13],
  269. [21, 22, 23],
  270. [31, 32, 33],
  271. [41, 42, 43]])
  272. """
  273. from sympy.matrices import Matrix
  274. if not self.is_shape_numeric:
  275. raise ValueError("A symbolic array cannot be expanded to a matrix")
  276. if self._rank != 2:
  277. raise ValueError('Dimensions must be of size of 2')
  278. return Matrix(self._expand_array().tomatrix())
  279. def isLambda(v):
  280. LAMBDA = lambda: 0
  281. return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__
  282. class ArrayComprehensionMap(ArrayComprehension):
  283. '''
  284. A subclass of ArrayComprehension dedicated to map external function lambda.
  285. Notes
  286. =====
  287. Only the lambda function is considered.
  288. At most one argument in lambda function is accepted in order to avoid ambiguity
  289. in value assignment.
  290. Examples
  291. ========
  292. >>> from sympy.tensor.array import ArrayComprehensionMap
  293. >>> from sympy import symbols
  294. >>> i, j, k = symbols('i j k')
  295. >>> a = ArrayComprehensionMap(lambda: 1, (i, 1, 4))
  296. >>> a.doit()
  297. [1, 1, 1, 1]
  298. >>> b = ArrayComprehensionMap(lambda a: a+1, (j, 1, 4))
  299. >>> b.doit()
  300. [2, 3, 4, 5]
  301. '''
  302. def __new__(cls, function, *symbols, **assumptions):
  303. if any(len(l) != 3 or None for l in symbols):
  304. raise ValueError('ArrayComprehension requires values lower and upper bound'
  305. ' for the expression')
  306. if not isLambda(function):
  307. raise ValueError('Data type not supported')
  308. arglist = cls._check_limits_validity(function, symbols)
  309. obj = Basic.__new__(cls, *arglist, **assumptions)
  310. obj._limits = obj._args
  311. obj._shape = cls._calculate_shape_from_limits(obj._limits)
  312. obj._rank = len(obj._shape)
  313. obj._loop_size = cls._calculate_loop_size(obj._shape)
  314. obj._lambda = function
  315. return obj
  316. @property
  317. def func(self):
  318. class _(ArrayComprehensionMap):
  319. def __new__(cls, *args, **kwargs):
  320. return ArrayComprehensionMap(self._lambda, *args, **kwargs)
  321. return _
  322. def _get_element(self, values):
  323. temp = self._lambda
  324. if self._lambda.__code__.co_argcount == 0:
  325. temp = temp()
  326. elif self._lambda.__code__.co_argcount == 1:
  327. temp = temp(functools.reduce(lambda a, b: a*b, values))
  328. return temp