ops_dispatch.pyx 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. DISPATCHED_UFUNCS = {
  2. "add",
  3. "sub",
  4. "mul",
  5. "pow",
  6. "mod",
  7. "floordiv",
  8. "truediv",
  9. "divmod",
  10. "eq",
  11. "ne",
  12. "lt",
  13. "gt",
  14. "le",
  15. "ge",
  16. "remainder",
  17. "matmul",
  18. "or",
  19. "xor",
  20. "and",
  21. "neg",
  22. "pos",
  23. "abs",
  24. }
  25. UNARY_UFUNCS = {
  26. "neg",
  27. "pos",
  28. "abs",
  29. }
  30. UFUNC_ALIASES = {
  31. "subtract": "sub",
  32. "multiply": "mul",
  33. "floor_divide": "floordiv",
  34. "true_divide": "truediv",
  35. "power": "pow",
  36. "remainder": "mod",
  37. "divide": "truediv",
  38. "equal": "eq",
  39. "not_equal": "ne",
  40. "less": "lt",
  41. "less_equal": "le",
  42. "greater": "gt",
  43. "greater_equal": "ge",
  44. "bitwise_or": "or",
  45. "bitwise_and": "and",
  46. "bitwise_xor": "xor",
  47. "negative": "neg",
  48. "absolute": "abs",
  49. "positive": "pos",
  50. }
  51. # For op(., Array) -> Array.__r{op}__
  52. REVERSED_NAMES = {
  53. "lt": "__gt__",
  54. "le": "__ge__",
  55. "gt": "__lt__",
  56. "ge": "__le__",
  57. "eq": "__eq__",
  58. "ne": "__ne__",
  59. }
  60. def maybe_dispatch_ufunc_to_dunder_op(
  61. object self, object ufunc, str method, *inputs, **kwargs
  62. ):
  63. """
  64. Dispatch a ufunc to the equivalent dunder method.
  65. Parameters
  66. ----------
  67. self : ArrayLike
  68. The array whose dunder method we dispatch to
  69. ufunc : Callable
  70. A NumPy ufunc
  71. method : {'reduce', 'accumulate', 'reduceat', 'outer', 'at', '__call__'}
  72. inputs : ArrayLike
  73. The input arrays.
  74. kwargs : Any
  75. The additional keyword arguments, e.g. ``out``.
  76. Returns
  77. -------
  78. result : Any
  79. The result of applying the ufunc
  80. """
  81. # special has the ufuncs we dispatch to the dunder op on
  82. op_name = ufunc.__name__
  83. op_name = UFUNC_ALIASES.get(op_name, op_name)
  84. def not_implemented(*args, **kwargs):
  85. return NotImplemented
  86. if kwargs or ufunc.nin > 2:
  87. return NotImplemented
  88. if method == "__call__" and op_name in DISPATCHED_UFUNCS:
  89. if inputs[0] is self:
  90. name = f"__{op_name}__"
  91. meth = getattr(self, name, not_implemented)
  92. if op_name in UNARY_UFUNCS:
  93. assert len(inputs) == 1
  94. return meth()
  95. return meth(inputs[1])
  96. elif inputs[1] is self:
  97. name = REVERSED_NAMES.get(op_name, f"__r{op_name}__")
  98. meth = getattr(self, name, not_implemented)
  99. result = meth(inputs[0])
  100. return result
  101. else:
  102. # should not be reached, but covering our bases
  103. return NotImplemented
  104. else:
  105. return NotImplemented