hadamard.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. from collections import Counter
  2. from sympy.core import Mul, sympify
  3. from sympy.core.add import Add
  4. from sympy.core.expr import ExprBuilder
  5. from sympy.core.sorting import default_sort_key
  6. from sympy.functions.elementary.exponential import log
  7. from sympy.matrices.expressions.matexpr import MatrixExpr
  8. from sympy.matrices.expressions._shape import validate_matadd_integer as validate
  9. from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix
  10. from sympy.strategies import (
  11. unpack, flatten, condition, exhaust, rm_id, sort
  12. )
  13. from sympy.utilities.exceptions import sympy_deprecation_warning
  14. def hadamard_product(*matrices):
  15. """
  16. Return the elementwise (aka Hadamard) product of matrices.
  17. Examples
  18. ========
  19. >>> from sympy import hadamard_product, MatrixSymbol
  20. >>> A = MatrixSymbol('A', 2, 3)
  21. >>> B = MatrixSymbol('B', 2, 3)
  22. >>> hadamard_product(A)
  23. A
  24. >>> hadamard_product(A, B)
  25. HadamardProduct(A, B)
  26. >>> hadamard_product(A, B)[0, 1]
  27. A[0, 1]*B[0, 1]
  28. """
  29. if not matrices:
  30. raise TypeError("Empty Hadamard product is undefined")
  31. if len(matrices) == 1:
  32. return matrices[0]
  33. return HadamardProduct(*matrices).doit()
  34. class HadamardProduct(MatrixExpr):
  35. """
  36. Elementwise product of matrix expressions
  37. Examples
  38. ========
  39. Hadamard product for matrix symbols:
  40. >>> from sympy import hadamard_product, HadamardProduct, MatrixSymbol
  41. >>> A = MatrixSymbol('A', 5, 5)
  42. >>> B = MatrixSymbol('B', 5, 5)
  43. >>> isinstance(hadamard_product(A, B), HadamardProduct)
  44. True
  45. Notes
  46. =====
  47. This is a symbolic object that simply stores its argument without
  48. evaluating it. To actually compute the product, use the function
  49. ``hadamard_product()`` or ``HadamardProduct.doit``
  50. """
  51. is_HadamardProduct = True
  52. def __new__(cls, *args, evaluate=False, check=None):
  53. args = list(map(sympify, args))
  54. if len(args) == 0:
  55. # We currently don't have a way to support one-matrices of generic dimensions:
  56. raise ValueError("HadamardProduct needs at least one argument")
  57. if not all(isinstance(arg, MatrixExpr) for arg in args):
  58. raise TypeError("Mix of Matrix and Scalar symbols")
  59. if check is not None:
  60. sympy_deprecation_warning(
  61. "Passing check to HadamardProduct is deprecated and the check argument will be removed in a future version.",
  62. deprecated_since_version="1.11",
  63. active_deprecations_target='remove-check-argument-from-matrix-operations')
  64. if check is not False:
  65. validate(*args)
  66. obj = super().__new__(cls, *args)
  67. if evaluate:
  68. obj = obj.doit(deep=False)
  69. return obj
  70. @property
  71. def shape(self):
  72. return self.args[0].shape
  73. def _entry(self, i, j, **kwargs):
  74. return Mul(*[arg._entry(i, j, **kwargs) for arg in self.args])
  75. def _eval_transpose(self):
  76. from sympy.matrices.expressions.transpose import transpose
  77. return HadamardProduct(*list(map(transpose, self.args)))
  78. def doit(self, **hints):
  79. expr = self.func(*(i.doit(**hints) for i in self.args))
  80. # Check for explicit matrices:
  81. from sympy.matrices.matrices import MatrixBase
  82. from sympy.matrices.immutable import ImmutableMatrix
  83. explicit = [i for i in expr.args if isinstance(i, MatrixBase)]
  84. if explicit:
  85. remainder = [i for i in expr.args if i not in explicit]
  86. expl_mat = ImmutableMatrix([
  87. Mul.fromiter(i) for i in zip(*explicit)
  88. ]).reshape(*self.shape)
  89. expr = HadamardProduct(*([expl_mat] + remainder))
  90. return canonicalize(expr)
  91. def _eval_derivative(self, x):
  92. terms = []
  93. args = list(self.args)
  94. for i in range(len(args)):
  95. factors = args[:i] + [args[i].diff(x)] + args[i+1:]
  96. terms.append(hadamard_product(*factors))
  97. return Add.fromiter(terms)
  98. def _eval_derivative_matrix_lines(self, x):
  99. from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal
  100. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  101. from sympy.matrices.expressions.matexpr import _make_matrix
  102. with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
  103. lines = []
  104. for ind in with_x_ind:
  105. left_args = self.args[:ind]
  106. right_args = self.args[ind+1:]
  107. d = self.args[ind]._eval_derivative_matrix_lines(x)
  108. hadam = hadamard_product(*(right_args + left_args))
  109. diagonal = [(0, 2), (3, 4)]
  110. diagonal = [e for j, e in enumerate(diagonal) if self.shape[j] != 1]
  111. for i in d:
  112. l1 = i._lines[i._first_line_index]
  113. l2 = i._lines[i._second_line_index]
  114. subexpr = ExprBuilder(
  115. ArrayDiagonal,
  116. [
  117. ExprBuilder(
  118. ArrayTensorProduct,
  119. [
  120. ExprBuilder(_make_matrix, [l1]),
  121. hadam,
  122. ExprBuilder(_make_matrix, [l2]),
  123. ]
  124. ),
  125. *diagonal],
  126. )
  127. i._first_pointer_parent = subexpr.args[0].args[0].args
  128. i._first_pointer_index = 0
  129. i._second_pointer_parent = subexpr.args[0].args[2].args
  130. i._second_pointer_index = 0
  131. i._lines = [subexpr]
  132. lines.append(i)
  133. return lines
# TODO Implement algorithm for rewriting Hadamard product as diagonal matrix
# if matmul identity matrix is multiplied.
  136. def canonicalize(x):
  137. """Canonicalize the Hadamard product ``x`` with mathematical properties.
  138. Examples
  139. ========
  140. >>> from sympy import MatrixSymbol, HadamardProduct
  141. >>> from sympy import OneMatrix, ZeroMatrix
  142. >>> from sympy.matrices.expressions.hadamard import canonicalize
  143. >>> from sympy import init_printing
  144. >>> init_printing(use_unicode=False)
  145. >>> A = MatrixSymbol('A', 2, 2)
  146. >>> B = MatrixSymbol('B', 2, 2)
  147. >>> C = MatrixSymbol('C', 2, 2)
  148. Hadamard product associativity:
  149. >>> X = HadamardProduct(A, HadamardProduct(B, C))
  150. >>> X
  151. A.*(B.*C)
  152. >>> canonicalize(X)
  153. A.*B.*C
  154. Hadamard product commutativity:
  155. >>> X = HadamardProduct(A, B)
  156. >>> Y = HadamardProduct(B, A)
  157. >>> X
  158. A.*B
  159. >>> Y
  160. B.*A
  161. >>> canonicalize(X)
  162. A.*B
  163. >>> canonicalize(Y)
  164. A.*B
  165. Hadamard product identity:
  166. >>> X = HadamardProduct(A, OneMatrix(2, 2))
  167. >>> X
  168. A.*1
  169. >>> canonicalize(X)
  170. A
  171. Absorbing element of Hadamard product:
  172. >>> X = HadamardProduct(A, ZeroMatrix(2, 2))
  173. >>> X
  174. A.*0
  175. >>> canonicalize(X)
  176. 0
  177. Rewriting to Hadamard Power
  178. >>> X = HadamardProduct(A, A, A)
  179. >>> X
  180. A.*A.*A
  181. >>> canonicalize(X)
  182. .3
  183. A
  184. Notes
  185. =====
  186. As the Hadamard product is associative, nested products can be flattened.
  187. The Hadamard product is commutative so that factors can be sorted for
  188. canonical form.
  189. A matrix of only ones is an identity for Hadamard product,
  190. so every matrices of only ones can be removed.
  191. Any zero matrix will make the whole product a zero matrix.
  192. Duplicate elements can be collected and rewritten as HadamardPower
  193. References
  194. ==========
  195. .. [1] https://en.wikipedia.org/wiki/Hadamard_product_(matrices)
  196. """
  197. # Associativity
  198. rule = condition(
  199. lambda x: isinstance(x, HadamardProduct),
  200. flatten
  201. )
  202. fun = exhaust(rule)
  203. x = fun(x)
  204. # Identity
  205. fun = condition(
  206. lambda x: isinstance(x, HadamardProduct),
  207. rm_id(lambda x: isinstance(x, OneMatrix))
  208. )
  209. x = fun(x)
  210. # Absorbing by Zero Matrix
  211. def absorb(x):
  212. if any(isinstance(c, ZeroMatrix) for c in x.args):
  213. return ZeroMatrix(*x.shape)
  214. else:
  215. return x
  216. fun = condition(
  217. lambda x: isinstance(x, HadamardProduct),
  218. absorb
  219. )
  220. x = fun(x)
  221. # Rewriting with HadamardPower
  222. if isinstance(x, HadamardProduct):
  223. tally = Counter(x.args)
  224. new_arg = []
  225. for base, exp in tally.items():
  226. if exp == 1:
  227. new_arg.append(base)
  228. else:
  229. new_arg.append(HadamardPower(base, exp))
  230. x = HadamardProduct(*new_arg)
  231. # Commutativity
  232. fun = condition(
  233. lambda x: isinstance(x, HadamardProduct),
  234. sort(default_sort_key)
  235. )
  236. x = fun(x)
  237. # Unpacking
  238. x = unpack(x)
  239. return x
  240. def hadamard_power(base, exp):
  241. base = sympify(base)
  242. exp = sympify(exp)
  243. if exp == 1:
  244. return base
  245. if not base.is_Matrix:
  246. return base**exp
  247. if exp.is_Matrix:
  248. raise ValueError("cannot raise expression to a matrix")
  249. return HadamardPower(base, exp)
  250. class HadamardPower(MatrixExpr):
  251. r"""
  252. Elementwise power of matrix expressions
  253. Parameters
  254. ==========
  255. base : scalar or matrix
  256. exp : scalar or matrix
  257. Notes
  258. =====
  259. There are four definitions for the hadamard power which can be used.
  260. Let's consider `A, B` as `(m, n)` matrices, and `a, b` as scalars.
  261. Matrix raised to a scalar exponent:
  262. .. math::
  263. A^{\circ b} = \begin{bmatrix}
  264. A_{0, 0}^b & A_{0, 1}^b & \cdots & A_{0, n-1}^b \\
  265. A_{1, 0}^b & A_{1, 1}^b & \cdots & A_{1, n-1}^b \\
  266. \vdots & \vdots & \ddots & \vdots \\
  267. A_{m-1, 0}^b & A_{m-1, 1}^b & \cdots & A_{m-1, n-1}^b
  268. \end{bmatrix}
  269. Scalar raised to a matrix exponent:
  270. .. math::
  271. a^{\circ B} = \begin{bmatrix}
  272. a^{B_{0, 0}} & a^{B_{0, 1}} & \cdots & a^{B_{0, n-1}} \\
  273. a^{B_{1, 0}} & a^{B_{1, 1}} & \cdots & a^{B_{1, n-1}} \\
  274. \vdots & \vdots & \ddots & \vdots \\
  275. a^{B_{m-1, 0}} & a^{B_{m-1, 1}} & \cdots & a^{B_{m-1, n-1}}
  276. \end{bmatrix}
  277. Matrix raised to a matrix exponent:
  278. .. math::
  279. A^{\circ B} = \begin{bmatrix}
  280. A_{0, 0}^{B_{0, 0}} & A_{0, 1}^{B_{0, 1}} &
  281. \cdots & A_{0, n-1}^{B_{0, n-1}} \\
  282. A_{1, 0}^{B_{1, 0}} & A_{1, 1}^{B_{1, 1}} &
  283. \cdots & A_{1, n-1}^{B_{1, n-1}} \\
  284. \vdots & \vdots &
  285. \ddots & \vdots \\
  286. A_{m-1, 0}^{B_{m-1, 0}} & A_{m-1, 1}^{B_{m-1, 1}} &
  287. \cdots & A_{m-1, n-1}^{B_{m-1, n-1}}
  288. \end{bmatrix}
  289. Scalar raised to a scalar exponent:
  290. .. math::
  291. a^{\circ b} = a^b
  292. """
  293. def __new__(cls, base, exp):
  294. base = sympify(base)
  295. exp = sympify(exp)
  296. if base.is_scalar and exp.is_scalar:
  297. return base ** exp
  298. if isinstance(base, MatrixExpr) and isinstance(exp, MatrixExpr):
  299. validate(base, exp)
  300. obj = super().__new__(cls, base, exp)
  301. return obj
  302. @property
  303. def base(self):
  304. return self._args[0]
  305. @property
  306. def exp(self):
  307. return self._args[1]
  308. @property
  309. def shape(self):
  310. if self.base.is_Matrix:
  311. return self.base.shape
  312. return self.exp.shape
  313. def _entry(self, i, j, **kwargs):
  314. base = self.base
  315. exp = self.exp
  316. if base.is_Matrix:
  317. a = base._entry(i, j, **kwargs)
  318. elif base.is_scalar:
  319. a = base
  320. else:
  321. raise ValueError(
  322. 'The base {} must be a scalar or a matrix.'.format(base))
  323. if exp.is_Matrix:
  324. b = exp._entry(i, j, **kwargs)
  325. elif exp.is_scalar:
  326. b = exp
  327. else:
  328. raise ValueError(
  329. 'The exponent {} must be a scalar or a matrix.'.format(exp))
  330. return a ** b
  331. def _eval_transpose(self):
  332. from sympy.matrices.expressions.transpose import transpose
  333. return HadamardPower(transpose(self.base), self.exp)
  334. def _eval_derivative(self, x):
  335. dexp = self.exp.diff(x)
  336. logbase = self.base.applyfunc(log)
  337. dlbase = logbase.diff(x)
  338. return hadamard_product(
  339. dexp*logbase + self.exp*dlbase,
  340. self
  341. )
  342. def _eval_derivative_matrix_lines(self, x):
  343. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  344. from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal
  345. from sympy.matrices.expressions.matexpr import _make_matrix
  346. lr = self.base._eval_derivative_matrix_lines(x)
  347. for i in lr:
  348. diagonal = [(1, 2), (3, 4)]
  349. diagonal = [e for j, e in enumerate(diagonal) if self.base.shape[j] != 1]
  350. l1 = i._lines[i._first_line_index]
  351. l2 = i._lines[i._second_line_index]
  352. subexpr = ExprBuilder(
  353. ArrayDiagonal,
  354. [
  355. ExprBuilder(
  356. ArrayTensorProduct,
  357. [
  358. ExprBuilder(_make_matrix, [l1]),
  359. self.exp*hadamard_power(self.base, self.exp-1),
  360. ExprBuilder(_make_matrix, [l2]),
  361. ]
  362. ),
  363. *diagonal],
  364. validator=ArrayDiagonal._validate
  365. )
  366. i._first_pointer_parent = subexpr.args[0].args[0].args
  367. i._first_pointer_index = 0
  368. i._first_line_index = 0
  369. i._second_pointer_parent = subexpr.args[0].args[2].args
  370. i._second_pointer_index = 0
  371. i._second_line_index = 0
  372. i._lines = [subexpr]
  373. return lr