decorators.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. """
  2. SymPy core decorators.
  3. The purpose of this module is to expose decorators without any other
  4. dependencies, so that they can be easily imported anywhere in sympy/core.
  5. """
  6. from functools import wraps
  7. from .sympify import SympifyError, sympify
  8. def _sympifyit(arg, retval=None):
  9. """
  10. decorator to smartly _sympify function arguments
  11. Explanation
  12. ===========
  13. @_sympifyit('other', NotImplemented)
  14. def add(self, other):
  15. ...
  16. In add, other can be thought of as already being a SymPy object.
  17. If it is not, the code is likely to catch an exception, then other will
  18. be explicitly _sympified, and the whole code restarted.
  19. if _sympify(arg) fails, NotImplemented will be returned
  20. See also
  21. ========
  22. __sympifyit
  23. """
  24. def deco(func):
  25. return __sympifyit(func, arg, retval)
  26. return deco
  27. def __sympifyit(func, arg, retval=None):
  28. """Decorator to _sympify `arg` argument for function `func`.
  29. Do not use directly -- use _sympifyit instead.
  30. """
  31. # we support f(a,b) only
  32. if not func.__code__.co_argcount:
  33. raise LookupError("func not found")
  34. # only b is _sympified
  35. assert func.__code__.co_varnames[1] == arg
  36. if retval is None:
  37. @wraps(func)
  38. def __sympifyit_wrapper(a, b):
  39. return func(a, sympify(b, strict=True))
  40. else:
  41. @wraps(func)
  42. def __sympifyit_wrapper(a, b):
  43. try:
  44. # If an external class has _op_priority, it knows how to deal
  45. # with SymPy objects. Otherwise, it must be converted.
  46. if not hasattr(b, '_op_priority'):
  47. b = sympify(b, strict=True)
  48. return func(a, b)
  49. except SympifyError:
  50. return retval
  51. return __sympifyit_wrapper
  52. def call_highest_priority(method_name):
  53. """A decorator for binary special methods to handle _op_priority.
  54. Explanation
  55. ===========
  56. Binary special methods in Expr and its subclasses use a special attribute
  57. '_op_priority' to determine whose special method will be called to
  58. handle the operation. In general, the object having the highest value of
  59. '_op_priority' will handle the operation. Expr and subclasses that define
  60. custom binary special methods (__mul__, etc.) should decorate those
  61. methods with this decorator to add the priority logic.
  62. The ``method_name`` argument is the name of the method of the other class
  63. that will be called. Use this decorator in the following manner::
  64. # Call other.__rmul__ if other._op_priority > self._op_priority
  65. @call_highest_priority('__rmul__')
  66. def __mul__(self, other):
  67. ...
  68. # Call other.__mul__ if other._op_priority > self._op_priority
  69. @call_highest_priority('__mul__')
  70. def __rmul__(self, other):
  71. ...
  72. """
  73. def priority_decorator(func):
  74. @wraps(func)
  75. def binary_op_wrapper(self, other):
  76. if hasattr(other, '_op_priority'):
  77. if other._op_priority > self._op_priority:
  78. f = getattr(other, method_name, None)
  79. if f is not None:
  80. return f(self)
  81. return func(self, other)
  82. return binary_op_wrapper
  83. return priority_decorator
  84. def sympify_method_args(cls):
  85. '''Decorator for a class with methods that sympify arguments.
  86. Explanation
  87. ===========
  88. The sympify_method_args decorator is to be used with the sympify_return
  89. decorator for automatic sympification of method arguments. This is
  90. intended for the common idiom of writing a class like :
  91. Examples
  92. ========
  93. >>> from sympy import Basic, SympifyError, S
  94. >>> from sympy.core.sympify import _sympify
  95. >>> class MyTuple(Basic):
  96. ... def __add__(self, other):
  97. ... try:
  98. ... other = _sympify(other)
  99. ... except SympifyError:
  100. ... return NotImplemented
  101. ... if not isinstance(other, MyTuple):
  102. ... return NotImplemented
  103. ... return MyTuple(*(self.args + other.args))
  104. >>> MyTuple(S(1), S(2)) + MyTuple(S(3), S(4))
  105. MyTuple(1, 2, 3, 4)
  106. In the above it is important that we return NotImplemented when other is
  107. not sympifiable and also when the sympified result is not of the expected
  108. type. This allows the MyTuple class to be used cooperatively with other
  109. classes that overload __add__ and want to do something else in combination
  110. with instance of Tuple.
  111. Using this decorator the above can be written as
  112. >>> from sympy.core.decorators import sympify_method_args, sympify_return
  113. >>> @sympify_method_args
  114. ... class MyTuple(Basic):
  115. ... @sympify_return([('other', 'MyTuple')], NotImplemented)
  116. ... def __add__(self, other):
  117. ... return MyTuple(*(self.args + other.args))
  118. >>> MyTuple(S(1), S(2)) + MyTuple(S(3), S(4))
  119. MyTuple(1, 2, 3, 4)
  120. The idea here is that the decorators take care of the boiler-plate code
  121. for making this happen in each method that potentially needs to accept
  122. unsympified arguments. Then the body of e.g. the __add__ method can be
  123. written without needing to worry about calling _sympify or checking the
  124. type of the resulting object.
  125. The parameters for sympify_return are a list of tuples of the form
  126. (parameter_name, expected_type) and the value to return (e.g.
  127. NotImplemented). The expected_type parameter can be a type e.g. Tuple or a
  128. string 'Tuple'. Using a string is useful for specifying a Type within its
  129. class body (as in the above example).
  130. Notes: Currently sympify_return only works for methods that take a single
  131. argument (not including self). Specifying an expected_type as a string
  132. only works for the class in which the method is defined.
  133. '''
  134. # Extract the wrapped methods from each of the wrapper objects created by
  135. # the sympify_return decorator. Doing this here allows us to provide the
  136. # cls argument which is used for forward string referencing.
  137. for attrname, obj in cls.__dict__.items():
  138. if isinstance(obj, _SympifyWrapper):
  139. setattr(cls, attrname, obj.make_wrapped(cls))
  140. return cls
  141. def sympify_return(*args):
  142. '''Function/method decorator to sympify arguments automatically
  143. See the docstring of sympify_method_args for explanation.
  144. '''
  145. # Store a wrapper object for the decorated method
  146. def wrapper(func):
  147. return _SympifyWrapper(func, args)
  148. return wrapper
  149. class _SympifyWrapper:
  150. '''Internal class used by sympify_return and sympify_method_args'''
  151. def __init__(self, func, args):
  152. self.func = func
  153. self.args = args
  154. def make_wrapped(self, cls):
  155. func = self.func
  156. parameters, retval = self.args
  157. # XXX: Handle more than one parameter?
  158. [(parameter, expectedcls)] = parameters
  159. # Handle forward references to the current class using strings
  160. if expectedcls == cls.__name__:
  161. expectedcls = cls
  162. # Raise RuntimeError since this is a failure at import time and should
  163. # not be recoverable.
  164. nargs = func.__code__.co_argcount
  165. # we support f(a, b) only
  166. if nargs != 2:
  167. raise RuntimeError('sympify_return can only be used with 2 argument functions')
  168. # only b is _sympified
  169. if func.__code__.co_varnames[1] != parameter:
  170. raise RuntimeError('parameter name mismatch "%s" in %s' %
  171. (parameter, func.__name__))
  172. @wraps(func)
  173. def _func(self, other):
  174. # XXX: The check for _op_priority here should be removed. It is
  175. # needed to stop mutable matrices from being sympified to
  176. # immutable matrices which breaks things in quantum...
  177. if not hasattr(other, '_op_priority'):
  178. try:
  179. other = sympify(other, strict=True)
  180. except SympifyError:
  181. return retval
  182. if not isinstance(other, expectedcls):
  183. return retval
  184. return func(self, other)
  185. return _func