modularinteger.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. """Implementation of :class:`ModularInteger` class. """
  2. from __future__ import annotations
  3. from typing import Any
  4. import operator
  5. from sympy.polys.polyutils import PicklableWithSlots
  6. from sympy.polys.polyerrors import CoercionFailed
  7. from sympy.polys.domains.domainelement import DomainElement
  8. from sympy.utilities import public
  9. @public
  10. class ModularInteger(PicklableWithSlots, DomainElement):
  11. """A class representing a modular integer. """
  12. mod, dom, sym, _parent = None, None, None, None
  13. __slots__ = ('val',)
  14. def parent(self):
  15. return self._parent
  16. def __init__(self, val):
  17. if isinstance(val, self.__class__):
  18. self.val = val.val % self.mod
  19. else:
  20. self.val = self.dom.convert(val) % self.mod
  21. def __hash__(self):
  22. return hash((self.val, self.mod))
  23. def __repr__(self):
  24. return "%s(%s)" % (self.__class__.__name__, self.val)
  25. def __str__(self):
  26. return "%s mod %s" % (self.val, self.mod)
  27. def __int__(self):
  28. return int(self.to_int())
  29. def to_int(self):
  30. if self.sym:
  31. if self.val <= self.mod // 2:
  32. return self.val
  33. else:
  34. return self.val - self.mod
  35. else:
  36. return self.val
  37. def __pos__(self):
  38. return self
  39. def __neg__(self):
  40. return self.__class__(-self.val)
  41. @classmethod
  42. def _get_val(cls, other):
  43. if isinstance(other, cls):
  44. return other.val
  45. else:
  46. try:
  47. return cls.dom.convert(other)
  48. except CoercionFailed:
  49. return None
  50. def __add__(self, other):
  51. val = self._get_val(other)
  52. if val is not None:
  53. return self.__class__(self.val + val)
  54. else:
  55. return NotImplemented
  56. def __radd__(self, other):
  57. return self.__add__(other)
  58. def __sub__(self, other):
  59. val = self._get_val(other)
  60. if val is not None:
  61. return self.__class__(self.val - val)
  62. else:
  63. return NotImplemented
  64. def __rsub__(self, other):
  65. return (-self).__add__(other)
  66. def __mul__(self, other):
  67. val = self._get_val(other)
  68. if val is not None:
  69. return self.__class__(self.val * val)
  70. else:
  71. return NotImplemented
  72. def __rmul__(self, other):
  73. return self.__mul__(other)
  74. def __truediv__(self, other):
  75. val = self._get_val(other)
  76. if val is not None:
  77. return self.__class__(self.val * self._invert(val))
  78. else:
  79. return NotImplemented
  80. def __rtruediv__(self, other):
  81. return self.invert().__mul__(other)
  82. def __mod__(self, other):
  83. val = self._get_val(other)
  84. if val is not None:
  85. return self.__class__(self.val % val)
  86. else:
  87. return NotImplemented
  88. def __rmod__(self, other):
  89. val = self._get_val(other)
  90. if val is not None:
  91. return self.__class__(val % self.val)
  92. else:
  93. return NotImplemented
  94. def __pow__(self, exp):
  95. if not exp:
  96. return self.__class__(self.dom.one)
  97. if exp < 0:
  98. val, exp = self.invert().val, -exp
  99. else:
  100. val = self.val
  101. return self.__class__(pow(val, int(exp), self.mod))
  102. def _compare(self, other, op):
  103. val = self._get_val(other)
  104. if val is not None:
  105. return op(self.val, val % self.mod)
  106. else:
  107. return NotImplemented
  108. def __eq__(self, other):
  109. return self._compare(other, operator.eq)
  110. def __ne__(self, other):
  111. return self._compare(other, operator.ne)
  112. def __lt__(self, other):
  113. return self._compare(other, operator.lt)
  114. def __le__(self, other):
  115. return self._compare(other, operator.le)
  116. def __gt__(self, other):
  117. return self._compare(other, operator.gt)
  118. def __ge__(self, other):
  119. return self._compare(other, operator.ge)
  120. def __bool__(self):
  121. return bool(self.val)
  122. @classmethod
  123. def _invert(cls, value):
  124. return cls.dom.invert(value, cls.mod)
  125. def invert(self):
  126. return self.__class__(self._invert(self.val))
  127. _modular_integer_cache: dict[tuple[Any, Any, Any], type[ModularInteger]] = {}
  128. def ModularIntegerFactory(_mod, _dom, _sym, parent):
  129. """Create custom class for specific integer modulus."""
  130. try:
  131. _mod = _dom.convert(_mod)
  132. except CoercionFailed:
  133. ok = False
  134. else:
  135. ok = True
  136. if not ok or _mod < 1:
  137. raise ValueError("modulus must be a positive integer, got %s" % _mod)
  138. key = _mod, _dom, _sym
  139. try:
  140. cls = _modular_integer_cache[key]
  141. except KeyError:
  142. class cls(ModularInteger):
  143. mod, dom, sym = _mod, _dom, _sym
  144. _parent = parent
  145. if _sym:
  146. cls.__name__ = "SymmetricModularIntegerMod%s" % _mod
  147. else:
  148. cls.__name__ = "ModularIntegerMod%s" % _mod
  149. _modular_integer_cache[key] = cls
  150. return cls