value_ranges.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import dataclasses
  2. import itertools
  3. import sympy # type: ignore[import]
  4. import operator
  5. import math
  6. import logging
  7. import torch
  8. from typing import Union
  9. log = logging.getLogger(__name__)
  10. @dataclasses.dataclass(frozen=True)
  11. class ValueRanges:
  12. lower: Union[sympy.Expr, sympy.Number, int, float, bool]
  13. upper: Union[sympy.Expr, sympy.Number, int, float, bool]
  14. def __contains__(self, x):
  15. # TODO This needs to be generalised if lower/upper are sympy.Expr
  16. assert not isinstance(x, sympy.Expr)
  17. return self.lower <= x <= self.upper
  18. @classmethod
  19. def wrap(cls, arg):
  20. if isinstance(arg, ValueRanges):
  21. return arg
  22. assert isinstance(arg, (int, float, bool))
  23. return ValueRanges(arg, arg)
  24. @classmethod
  25. def increasing_map(cls, x, fn):
  26. """map lower and upper bound with fn"""
  27. x = cls.wrap(x)
  28. return ValueRanges(fn(x.lower), fn(x.upper))
  29. @classmethod
  30. def decreasing_map(cls, x, fn):
  31. """map lower bound to upper bound and upper bound to lower bound"""
  32. x = cls.wrap(x)
  33. return ValueRanges(fn(x.upper), fn(x.lower))
  34. @classmethod
  35. def monotone_map(cls, x, fn):
  36. """check the max and min of computed upper and lower bound for the output"""
  37. x = cls.wrap(x)
  38. l = fn(x.lower)
  39. u = fn(x.upper)
  40. return ValueRanges(min(l, u), max(l, u))
  41. @classmethod
  42. def convex_min_zero_map(cls, x, fn):
  43. """the max is at one of the ends"""
  44. x = ValueRanges.wrap(x)
  45. if 0 in x:
  46. return ValueRanges(0, max(fn(x.lower), fn(x.upper)))
  47. else:
  48. return cls.monotone_map(x, fn)
  49. @classmethod
  50. def coordinatewise_increasing_map(cls, x, y, fn):
  51. """map upper and lower bounds accessing corresponding values of inputs"""
  52. x, y = cls.wrap(x), cls.wrap(y)
  53. return ValueRanges(
  54. fn(x.lower, y.lower),
  55. fn(x.upper, y.upper),
  56. )
  57. @classmethod
  58. def coordinatewise_monotone_map(cls, x, y, fn):
  59. """compute the product of all lower and upper bounds and take min and max"""
  60. x, y = cls.wrap(x), cls.wrap(y)
  61. products = [
  62. fn(a, b)
  63. for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper])
  64. ]
  65. return ValueRanges(min(products), max(products))
  66. class ValueRangeAnalysis:
  67. def __init__(self):
  68. self.name = "ValueRangeAnalysis"
  69. boolean_operators = (
  70. "eq",
  71. "ne",
  72. "lt",
  73. "gt",
  74. "le",
  75. "ge",
  76. "and_",
  77. "or_",
  78. "xor",
  79. "logical_and",
  80. "logical_or",
  81. "logical_not",
  82. )
  83. for op in boolean_operators:
  84. setattr(self, op, self.bool_handler)
  85. @staticmethod
  86. def bool_handler(*args, **kwargs):
  87. # just assuming bools can have both values
  88. return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
  89. @staticmethod
  90. def default_handler(*args, **kwargs):
  91. # many ops are unlikely to show up in optimizable indexing compute,
  92. # so we dont have full coverage
  93. return ValueRanges(-math.inf, math.inf)
  94. def load(self, name: str, index: sympy.Expr):
  95. return ValueRanges(-math.inf, math.inf)
  96. def store(self, name, index, value, mode=None):
  97. return
  98. def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
  99. return ValueRanges(-math.inf, math.inf)
  100. def index_expr(self, index, dtype):
  101. assert isinstance(index, ValueRanges)
  102. return index
  103. @staticmethod
  104. def to_dtype(x, dtype: torch.dtype):
  105. def is_bool(val):
  106. return isinstance(val, bool) or (
  107. hasattr(val, "is_Boolean") and val.is_Boolean
  108. )
  109. x = ValueRanges.wrap(x)
  110. low, up = x.lower, x.upper
  111. if is_bool(low):
  112. assert is_bool(up)
  113. if dtype.is_floating_point:
  114. return ValueRanges(sympy.Float(0.0), sympy.Float(1.0))
  115. else:
  116. return ValueRanges(sympy.Integer(0), sympy.Integer(1))
  117. return ValueRanges.wrap(x)
  118. @staticmethod
  119. def constant(value, dtype):
  120. # using nan makes subsequent computation throw, and for the purposes of optimization
  121. # returning -math.inf - math.inf is equivalent to giving up
  122. if math.isnan(value):
  123. return ValueRanges(-math.inf, math.inf)
  124. if isinstance(value, int):
  125. return ValueRanges(sympy.Integer(value), sympy.Integer(value))
  126. else:
  127. return ValueRanges(sympy.Float(value), sympy.Float(value))
  128. @staticmethod
  129. def reciprocal(x):
  130. x = ValueRanges.wrap(x)
  131. if 0 in x:
  132. return ValueRanges(-math.inf, math.inf)
  133. else:
  134. return ValueRanges.decreasing_map(x, lambda y: 1 / y)
  135. @staticmethod
  136. def square(x):
  137. return ValueRanges.convex_min_zero_map(x, lambda y: y * y)
  138. @staticmethod
  139. def abs(x):
  140. return ValueRanges.convex_min_zero_map(x, abs)
  141. @staticmethod
  142. def neg(x):
  143. return ValueRanges.decreasing_map(x, operator.neg)
  144. @staticmethod
  145. def truediv(a, b):
  146. b = ValueRanges.wrap(b)
  147. if 0 in b:
  148. return ValueRanges(-math.inf, math.inf)
  149. else:
  150. return ValueRangeAnalysis.mul(a, ValueRanges(1 / b.upper, 1 / b.lower))
  151. @staticmethod
  152. def div(a, b):
  153. # We think of this as floor(a / b)
  154. out = ValueRangeAnalysis.truediv(a, b)
  155. return ValueRangeAnalysis.floor(out)
  156. @staticmethod
  157. def add(a, b):
  158. return ValueRanges.coordinatewise_increasing_map(a, b, operator.add)
  159. @staticmethod
  160. def mul(a, b):
  161. return ValueRanges.coordinatewise_monotone_map(a, b, operator.mul)
  162. @staticmethod
  163. def sub(a, b):
  164. b = ValueRanges.wrap(b)
  165. return ValueRangeAnalysis.add(a, ValueRanges(-b.upper, -b.lower))
  166. @staticmethod
  167. def exp(x):
  168. return ValueRanges.increasing_map(x, sympy.functions.elementary.exponential.exp)
  169. @staticmethod
  170. def log(x):
  171. return ValueRanges.increasing_map(
  172. x, lambda y: -math.inf if y <= 0 else sympy.log(y)
  173. )
  174. @staticmethod
  175. def sqrt(x):
  176. return ValueRanges.increasing_map(x, sympy.sqrt)
  177. @staticmethod
  178. def pow(a, b):
  179. def is_integer(val):
  180. return (
  181. isinstance(val, int)
  182. or (isinstance(val, float) and val == int(val))
  183. or (hasattr(val, "is_integer") and val.is_integer)
  184. )
  185. a = ValueRanges.wrap(a)
  186. b = ValueRanges.wrap(b)
  187. if a.lower < 0 and not is_integer(b.lower):
  188. # The function is not defined
  189. return ValueRanges(-math.inf, math.inf)
  190. elif 0 in a and b.lower <= 0:
  191. return ValueRanges(-math.inf, math.inf)
  192. return ValueRanges.coordinatewise_monotone_map(a, b, operator.pow)
  193. @staticmethod
  194. def minimum(a, b):
  195. return ValueRanges.coordinatewise_increasing_map(a, b, min)
  196. @staticmethod
  197. def maximum(a, b):
  198. return ValueRanges.coordinatewise_increasing_map(a, b, max)
  199. @staticmethod
  200. def where(a, b, c):
  201. b = ValueRanges.wrap(b)
  202. c = ValueRanges.wrap(c)
  203. return ValueRanges(min(b.lower, c.lower), max(b.upper, c.upper))
  204. @staticmethod
  205. def floor(x):
  206. return ValueRangeAnalysis.floor_ceil(
  207. x, sympy.functions.elementary.integers.floor
  208. )
  209. @staticmethod
  210. def ceil(x):
  211. return ValueRangeAnalysis.floor_ceil(
  212. x, sympy.functions.elementary.integers.ceiling
  213. )
  214. @staticmethod
  215. def floor_ceil(x, fn_int):
  216. def is_integer(val):
  217. return isinstance(val, int) or (
  218. hasattr(val, "is_integer") and val.is_integer
  219. )
  220. if is_integer(x):
  221. fn = fn_int
  222. else:
  223. def fn(x):
  224. return sympy.Float(fn_int(x))
  225. return ValueRanges.increasing_map(x, fn)
  226. def __getattr__(self, name):
  227. log.warning(f"unhandled ValueRange op {name}")
  228. return self.default_handler