# optimize_indexing.py (11 KB)
  1. import functools
  2. import logging
  3. import math
  4. from typing import Dict, Iterable, Union
  5. import sympy
  6. import torch
  7. from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
  8. from .ir import FloorDiv, InterpreterShim, LoopBody, ModularIndexing
  9. from .utils import sympy_subs
  10. from .virtualized import V
  11. log = logging.getLogger(__name__)
  12. def dominated_nodes(
  13. initial_queue: Union[torch.fx.Node, Iterable[torch.fx.Node]], skip_filter=None
  14. ):
  15. """Returns the set of nodes whose values depend on those within initial_queue"""
  16. if isinstance(initial_queue, torch.fx.Node):
  17. initial_queue = [initial_queue]
  18. dominated_set = set(initial_queue)
  19. while initial_queue:
  20. node = initial_queue.pop()
  21. for user in node.users:
  22. if skip_filter and skip_filter(user):
  23. continue
  24. if user not in dominated_set:
  25. dominated_set.add(user)
  26. initial_queue.append(user)
  27. return dominated_set
  28. def val_expressable_in_32_bits(val):
  29. if hasattr(val, "is_Boolean") and val.is_Boolean:
  30. return True
  31. if isinstance(val, sympy.Expr):
  32. assert val.is_constant()
  33. if val.is_Integer or val.is_Boolean:
  34. val = int(val)
  35. else:
  36. val = float(val)
  37. # bound within mantissa
  38. if isinstance(val, float):
  39. return val <= (2**24) and val >= -(2**24)
  40. if isinstance(val, int):
  41. iinfo = torch.iinfo(torch.int32)
  42. return val <= iinfo.max and val >= iinfo.min
  43. raise Exception(f"Unexpected value {val}")
  44. def range_expressable_in_32_bits(range):
  45. return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits(
  46. range.upper
  47. )
  48. class OptimizeIndexing:
  49. """
  50. Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
  51. intermediaries from int64 to int32. This is an important optimization for indexing
  52. kernels such as Upsample and Interpolate.
  53. """
  54. def __init__(
  55. self,
  56. loop_body: LoopBody,
  57. indices_ranges: Dict[sympy.Symbol, int],
  58. indexing_exprs: Dict[str, sympy.Expr],
  59. ):
  60. self.loop_body = loop_body
  61. self.indices_range = indices_ranges
  62. self.indexing_exprs = indexing_exprs
  63. self.replacement_vals = {}
  64. self.interp_env = {}
  65. self.submodules = self.swap_submodules(dict(loop_body.submodules))
  66. indirect_var_set = set(loop_body.indirect_vars)
  67. self.index_indirect_dependecies = {
  68. index: expr.free_symbols & indirect_var_set
  69. for index, expr in indexing_exprs.items()
  70. }
  71. self.all_graphs = [loop_body.root_block.graph] + [
  72. block.graph for block in loop_body.subblocks.values()
  73. ]
  74. for k, v in indices_ranges.items():
  75. self.replace_indirect(k, ValueRanges(0, v))
  76. # avoid computing these values, pessimistically assume that they are unbounded
  77. self.tensor_values_set = dominated_nodes(
  78. [
  79. node
  80. for node in self.all_nodes
  81. if node.target in ["load", "reduction"]
  82. or "masked_subblock" in node.target
  83. ]
  84. )
  85. def run(self):
  86. """Compute Value Ranges and try reduce precision of 'to_dtype' nodes to int32 where possible"""
  87. int64_dtype_nodes = [
  88. node
  89. for node in self.all_nodes
  90. if (
  91. node.target == "to_dtype"
  92. and node.args[2] == torch.int64
  93. and node not in self.tensor_values_set
  94. )
  95. ]
  96. if not int64_dtype_nodes:
  97. return
  98. for node in self.tensor_values_set:
  99. # we need to evaluate masked_subblock to recurse, and we need to set indirect values
  100. if (
  101. "masked_subblock" not in node.target
  102. and "set_indirect" not in node.target
  103. ):
  104. self.interp_env[node] = torch._inductor.optimize_indexing.ValueRanges(
  105. -math.inf, math.inf
  106. )
  107. interpreter = InterpreterShim(self.loop_body.root_block.graph, self.submodules)
  108. interpreter.run(V.get_ops_handler(), initial_env=self.interp_env)
  109. # TODO - if dominated node of one to_dtype is not expressible in int32,
  110. # we should short circuit another to_dtype node if that node also dominates
  111. for node in int64_dtype_nodes:
  112. self.try_to_reduce_precision(node)
  113. def try_to_reduce_precision(self, node):
  114. # if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
  115. # then it's precision is set for that chain of uses, and we don't need to consider those
  116. # dominated values
  117. def skip_filter(node):
  118. return node.target == "to_dtype" and node.args[2] in (
  119. torch.int32,
  120. torch.float32,
  121. torch.float64,
  122. )
  123. # TODO - there are dominated uses whose dtype does not depend on whether
  124. # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
  125. # int32 without changing the output precision of the node. this case hasn't shown up
  126. for dominated in dominated_nodes(node, skip_filter):
  127. if dominated.target in ["store", "output"]:
  128. continue
  129. if "set_indirect" in dominated.target:
  130. idx = int(dominated.target[len("set_indirect") :])
  131. indirect_var = self.loop_body.indirect_vars[idx]
  132. for index, indirect_vals in self.index_indirect_dependecies.items():
  133. if indirect_var in indirect_vals:
  134. index_val = self.replacement_vals[index]
  135. if math.isinf(index_val.lower) or math.isinf(index_val.upper):
  136. return
  137. # all indices are integers, so make sure that we
  138. # use the bounds of integers instead of floats.
  139. # TODO - not sure if we should be doing int/float casts while tracing,
  140. # might interfere with sympy.
  141. index_val_int = ValueRanges(
  142. int(index_val.lower), int(index_val.upper)
  143. )
  144. if not range_expressable_in_32_bits(index_val_int):
  145. return
  146. if not range_expressable_in_32_bits(self.interp_env[dominated]):
  147. return
  148. args = list(node.args)
  149. args[2] = torch.int32
  150. node.args = tuple(args)
  151. @property
  152. def all_nodes(self):
  153. for graph in self.all_graphs:
  154. for node in graph.nodes:
  155. yield node
  156. def swap_submodules(self, submodules):
  157. keys = list(submodules.keys())
  158. for key in keys:
  159. if key == "get_index":
  160. submodules[key] = self.get_index
  161. elif "masked_subblock" in key:
  162. subblock = self.loop_body.subblocks[key]
  163. submodules[key] = functools.partial(
  164. self.masked_subblock, subblock, self.interp_env
  165. )
  166. else:
  167. assert "set_indirect" in key
  168. idx = int(key[len("set_indirect") :])
  169. var = self.loop_body.indirect_vars[idx]
  170. indirect = functools.partial(self.set_indirect, var)
  171. submodules[key] = indirect
  172. return submodules
  173. def masked_subblock(self, subblock, env, mask, value):
  174. interp = InterpreterShim(subblock.graph, self.submodules)
  175. interp.run(V.get_ops_handler(), initial_env=env)
  176. output = [node for node in subblock.graph.nodes if node.target == "output"]
  177. assert len(output) == 1
  178. # dont bother unioning with value since the load from buffer will be
  179. # pessimistically assumed to be inf anyway
  180. return interp.env[output[0]]
  181. def set_indirect(self, var, new_var):
  182. self.replace_indirect(var, new_var)
  183. return new_var
  184. def replace_indirect(self, old, new):
  185. """Swap in a variable used in indirect indexing"""
  186. assert isinstance(new, ValueRanges)
  187. self.replacement_vals[old] = new
  188. def get_index(self, name):
  189. if name in self.replacement_vals:
  190. return self.replacement_vals[name]
  191. out = self._get_index_impl(name)
  192. self.replacement_vals[name] = out
  193. return out
  194. def _get_index_impl(self, name):
  195. expr = self.indexing_exprs[name]
  196. free_symbols = list(expr.free_symbols)
  197. if len(free_symbols) == 0:
  198. return ValueRanges(expr, expr)
  199. if expr in self.replacement_vals:
  200. return self.replacement_vals[expr]
  201. def replace_symbols_for_deriv(expr, ignore_mod=False):
  202. # for the purposes of finding local, minimum, maximum, assume smoothness
  203. def mod_indexing_rep(x, y, z):
  204. if z.is_constant():
  205. return x / y
  206. # never really happens, we'll bail on optimizing
  207. return (x / y) % z
  208. def indexing_div_rep(x, y):
  209. return x / y
  210. return expr.replace(ModularIndexing, mod_indexing_rep).replace(
  211. FloorDiv, indexing_div_rep
  212. )
  213. symbols = expr.free_symbols
  214. monotonic_increasing = []
  215. monotonic_decreasing = []
  216. other_symbols = []
  217. expr_for_deriv = replace_symbols_for_deriv(expr, True)
  218. for symbol in symbols:
  219. diff = sympy.diff(expr_for_deriv, symbol)
  220. if diff.is_positive:
  221. monotonic_increasing.append(symbol)
  222. elif diff.is_positive is False: # can return None
  223. monotonic_decreasing.append(symbol)
  224. else:
  225. other_symbols.append(symbol)
  226. if not other_symbols:
  227. max_val = sympy_subs(
  228. expr,
  229. {
  230. k: (v.upper if k in monotonic_increasing else v.lower)
  231. for k, v in self.replacement_vals.items()
  232. },
  233. )
  234. min_val = sympy_subs(
  235. expr,
  236. {
  237. k: (v.lower if k in monotonic_increasing else v.upper)
  238. for k, v in self.replacement_vals.items()
  239. },
  240. )
  241. return ValueRanges(min_val, max_val)
  242. else:
  243. # bail on optimizing, have not run into this yet
  244. return ValueRanges(-math.inf, math.inf)
  245. def indexing_dtype_strength_reduction(loop_body: LoopBody):
  246. """
  247. Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
  248. intermediaries from int64 to int32
  249. """
  250. indices = dict(loop_body.var_ranges)
  251. indexing = dict(loop_body.indexing_exprs)
  252. with V.set_ops_handler(ValueRangeAnalysis()):
  253. OptimizeIndexing(loop_body, indices, indexing).run()