pythonmpq.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
"""
PythonMPQ: Rational number type based on Python integers.

This class is intended as a pure Python fallback for when gmpy2 is not
installed. If gmpy2 is installed then its mpq type will be used instead. The
mpq type is around 20x faster. We could just use the stdlib Fraction class
here but that is slower:

    from fractions import Fraction
    from sympy.external.pythonmpq import PythonMPQ
    nums = range(1000)
    dens = range(5, 1005)
    rats = [Fraction(n, d) for n, d in zip(nums, dens)]
    sum(rats) # <--- 24 milliseconds
    rats = [PythonMPQ(n, d) for n, d in zip(nums, dens)]
    sum(rats) # <--- 7 milliseconds

Both mpq and Fraction have some awkward features like the behaviour of
division with // and %:

    >>> from fractions import Fraction
    >>> Fraction(2, 3) % Fraction(1, 4)
    1/6

For the QQ domain we do not want this behaviour because there should be no
remainder when dividing rational numbers. SymPy does not make use of this
aspect of mpq when gmpy2 is installed. Since this class is a fallback for that
case we do not bother implementing e.g. __mod__ so that we can be sure we
are not using it when gmpy2 is installed either.
"""
  26. import operator
  27. from math import gcd
  28. from decimal import Decimal
  29. from fractions import Fraction
  30. import sys
  31. from typing import Tuple as tTuple, Type
  32. # Used for __hash__
  33. _PyHASH_MODULUS = sys.hash_info.modulus
  34. _PyHASH_INF = sys.hash_info.inf
  35. class PythonMPQ:
  36. """Rational number implementation that is intended to be compatible with
  37. gmpy2's mpq.
  38. Also slightly faster than fractions.Fraction.
  39. PythonMPQ should be treated as immutable although no effort is made to
  40. prevent mutation (since that might slow down calculations).
  41. """
  42. __slots__ = ('numerator', 'denominator')
  43. def __new__(cls, numerator, denominator=None):
  44. """Construct PythonMPQ with gcd computation and checks"""
  45. if denominator is not None:
  46. #
  47. # PythonMPQ(n, d): require n and d to be int and d != 0
  48. #
  49. if isinstance(numerator, int) and isinstance(denominator, int):
  50. # This is the slow part:
  51. divisor = gcd(numerator, denominator)
  52. numerator //= divisor
  53. denominator //= divisor
  54. return cls._new_check(numerator, denominator)
  55. else:
  56. #
  57. # PythonMPQ(q)
  58. #
  59. # Here q can be PythonMPQ, int, Decimal, float, Fraction or str
  60. #
  61. if isinstance(numerator, int):
  62. return cls._new(numerator, 1)
  63. elif isinstance(numerator, PythonMPQ):
  64. return cls._new(numerator.numerator, numerator.denominator)
  65. # Let Fraction handle Decimal/float conversion and str parsing
  66. if isinstance(numerator, (Decimal, float, str)):
  67. numerator = Fraction(numerator)
  68. if isinstance(numerator, Fraction):
  69. return cls._new(numerator.numerator, numerator.denominator)
  70. #
  71. # Reject everything else. This is more strict than mpq which allows
  72. # things like mpq(Fraction, Fraction) or mpq(Decimal, any). The mpq
  73. # behaviour is somewhat inconsistent so we choose to accept only a
  74. # more strict subset of what mpq allows.
  75. #
  76. raise TypeError("PythonMPQ() requires numeric or string argument")
  77. @classmethod
  78. def _new_check(cls, numerator, denominator):
  79. """Construct PythonMPQ, check divide by zero and canonicalize signs"""
  80. if not denominator:
  81. raise ZeroDivisionError(f'Zero divisor {numerator}/{denominator}')
  82. elif denominator < 0:
  83. numerator = -numerator
  84. denominator = -denominator
  85. return cls._new(numerator, denominator)
  86. @classmethod
  87. def _new(cls, numerator, denominator):
  88. """Construct PythonMPQ efficiently (no checks)"""
  89. obj = super().__new__(cls)
  90. obj.numerator = numerator
  91. obj.denominator = denominator
  92. return obj
  93. def __int__(self):
  94. """Convert to int (truncates towards zero)"""
  95. p, q = self.numerator, self.denominator
  96. if p < 0:
  97. return -(-p//q)
  98. return p//q
  99. def __float__(self):
  100. """Convert to float (approximately)"""
  101. return self.numerator / self.denominator
  102. def __bool__(self):
  103. """True/False if nonzero/zero"""
  104. return bool(self.numerator)
  105. def __eq__(self, other):
  106. """Compare equal with PythonMPQ, int, float, Decimal or Fraction"""
  107. if isinstance(other, PythonMPQ):
  108. return (self.numerator == other.numerator
  109. and self.denominator == other.denominator)
  110. elif isinstance(other, self._compatible_types):
  111. return self.__eq__(PythonMPQ(other))
  112. else:
  113. return NotImplemented
  114. def __hash__(self):
  115. """hash - same as mpq/Fraction"""
  116. try:
  117. dinv = pow(self.denominator, -1, _PyHASH_MODULUS)
  118. except ValueError:
  119. hash_ = _PyHASH_INF
  120. else:
  121. hash_ = hash(hash(abs(self.numerator)) * dinv)
  122. result = hash_ if self.numerator >= 0 else -hash_
  123. return -2 if result == -1 else result
  124. def __reduce__(self):
  125. """Deconstruct for pickling"""
  126. return type(self), (self.numerator, self.denominator)
  127. def __str__(self):
  128. """Convert to string"""
  129. if self.denominator != 1:
  130. return f"{self.numerator}/{self.denominator}"
  131. else:
  132. return f"{self.numerator}"
  133. def __repr__(self):
  134. """Convert to string"""
  135. return f"MPQ({self.numerator},{self.denominator})"
  136. def _cmp(self, other, op):
  137. """Helper for lt/le/gt/ge"""
  138. if not isinstance(other, self._compatible_types):
  139. return NotImplemented
  140. lhs = self.numerator * other.denominator
  141. rhs = other.numerator * self.denominator
  142. return op(lhs, rhs)
  143. def __lt__(self, other):
  144. """self < other"""
  145. return self._cmp(other, operator.lt)
  146. def __le__(self, other):
  147. """self <= other"""
  148. return self._cmp(other, operator.le)
  149. def __gt__(self, other):
  150. """self > other"""
  151. return self._cmp(other, operator.gt)
  152. def __ge__(self, other):
  153. """self >= other"""
  154. return self._cmp(other, operator.ge)
  155. def __abs__(self):
  156. """abs(q)"""
  157. return self._new(abs(self.numerator), self.denominator)
  158. def __pos__(self):
  159. """+q"""
  160. return self
  161. def __neg__(self):
  162. """-q"""
  163. return self._new(-self.numerator, self.denominator)
  164. def __add__(self, other):
  165. """q1 + q2"""
  166. if isinstance(other, PythonMPQ):
  167. #
  168. # This is much faster than the naive method used in the stdlib
  169. # fractions module. Not sure where this method comes from
  170. # though...
  171. #
  172. # Compare timings for something like:
  173. # nums = range(1000)
  174. # rats = [PythonMPQ(n, d) for n, d in zip(nums[:-5], nums[5:])]
  175. # sum(rats) # <-- time this
  176. #
  177. ap, aq = self.numerator, self.denominator
  178. bp, bq = other.numerator, other.denominator
  179. g = gcd(aq, bq)
  180. if g == 1:
  181. p = ap*bq + aq*bp
  182. q = bq*aq
  183. else:
  184. q1, q2 = aq//g, bq//g
  185. p, q = ap*q2 + bp*q1, q1*q2
  186. g2 = gcd(p, g)
  187. p, q = (p // g2), q * (g // g2)
  188. elif isinstance(other, int):
  189. p = self.numerator + self.denominator * other
  190. q = self.denominator
  191. else:
  192. return NotImplemented
  193. return self._new(p, q)
  194. def __radd__(self, other):
  195. """z1 + q2"""
  196. if isinstance(other, int):
  197. p = self.numerator + self.denominator * other
  198. q = self.denominator
  199. return self._new(p, q)
  200. else:
  201. return NotImplemented
  202. def __sub__(self ,other):
  203. """q1 - q2"""
  204. if isinstance(other, PythonMPQ):
  205. ap, aq = self.numerator, self.denominator
  206. bp, bq = other.numerator, other.denominator
  207. g = gcd(aq, bq)
  208. if g == 1:
  209. p = ap*bq - aq*bp
  210. q = bq*aq
  211. else:
  212. q1, q2 = aq//g, bq//g
  213. p, q = ap*q2 - bp*q1, q1*q2
  214. g2 = gcd(p, g)
  215. p, q = (p // g2), q * (g // g2)
  216. elif isinstance(other, int):
  217. p = self.numerator - self.denominator*other
  218. q = self.denominator
  219. else:
  220. return NotImplemented
  221. return self._new(p, q)
  222. def __rsub__(self, other):
  223. """z1 - q2"""
  224. if isinstance(other, int):
  225. p = self.denominator * other - self.numerator
  226. q = self.denominator
  227. return self._new(p, q)
  228. else:
  229. return NotImplemented
  230. def __mul__(self, other):
  231. """q1 * q2"""
  232. if isinstance(other, PythonMPQ):
  233. ap, aq = self.numerator, self.denominator
  234. bp, bq = other.numerator, other.denominator
  235. x1 = gcd(ap, bq)
  236. x2 = gcd(bp, aq)
  237. p, q = ((ap//x1)*(bp//x2), (aq//x2)*(bq//x1))
  238. elif isinstance(other, int):
  239. x = gcd(other, self.denominator)
  240. p = self.numerator*(other//x)
  241. q = self.denominator//x
  242. else:
  243. return NotImplemented
  244. return self._new(p, q)
  245. def __rmul__(self, other):
  246. """z1 * q2"""
  247. if isinstance(other, int):
  248. x = gcd(self.denominator, other)
  249. p = self.numerator*(other//x)
  250. q = self.denominator//x
  251. return self._new(p, q)
  252. else:
  253. return NotImplemented
  254. def __pow__(self, exp):
  255. """q ** z"""
  256. p, q = self.numerator, self.denominator
  257. if exp < 0:
  258. p, q, exp = q, p, -exp
  259. return self._new_check(p**exp, q**exp)
  260. def __truediv__(self, other):
  261. """q1 / q2"""
  262. if isinstance(other, PythonMPQ):
  263. ap, aq = self.numerator, self.denominator
  264. bp, bq = other.numerator, other.denominator
  265. x1 = gcd(ap, bp)
  266. x2 = gcd(bq, aq)
  267. p, q = ((ap//x1)*(bq//x2), (aq//x2)*(bp//x1))
  268. elif isinstance(other, int):
  269. x = gcd(other, self.numerator)
  270. p = self.numerator//x
  271. q = self.denominator*(other//x)
  272. else:
  273. return NotImplemented
  274. return self._new_check(p, q)
  275. def __rtruediv__(self, other):
  276. """z / q"""
  277. if isinstance(other, int):
  278. x = gcd(self.numerator, other)
  279. p = self.denominator*(other//x)
  280. q = self.numerator//x
  281. return self._new_check(p, q)
  282. else:
  283. return NotImplemented
  284. _compatible_types: tTuple[Type, ...] = ()
  285. #
  286. # These are the types that PythonMPQ will interoperate with for operations
  287. # and comparisons such as ==, + etc. We define this down here so that we can
  288. # include PythonMPQ in the list as well.
  289. #
  290. PythonMPQ._compatible_types = (PythonMPQ, int, Decimal, Fraction)