sizevars.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. import dataclasses
  2. import functools
  3. import itertools
  4. import logging
  5. from typing import Callable, Dict, List, Tuple
  6. import sympy
  7. from sympy import Expr
  8. from torch.fx.experimental.symbolic_shapes import ShapeEnv
  9. from . import ir
  10. from .codegen.common import IndentedBuffer
  11. from .utils import sympy_subs, sympy_symbol, VarRanges
  12. from .virtualized import V
  13. log = logging.getLogger(__name__)
@dataclasses.dataclass
class ZeroGuard:
    """
    An expression we should check equals zero.
    Guards are currently not checked.  Plan to add this later.
    """

    # The sympy expression that is asserted to evaluate to zero.
    expr: Expr
@dataclasses.dataclass
class PositiveGuard:
    """
    An expression we should check for > 0
    Guards are currently not checked.  Plan to add this later.
    """

    # The sympy expression that is asserted to be strictly positive.
    expr: Expr
class SizeVarAllocator:
    """
    Manages symbolic size variables for a graph being compiled.

    Wraps a ``ShapeEnv`` (shared symbol-to-value hints and replacements),
    records guards about symbolic sizes, caches expensive indexing
    simplifications, and generates the code that binds symbolic shape
    values to local variables in the output wrapper.
    """

    def __init__(self, shape_env=None):
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        # Hints: maps each sympy symbol to a concrete example value.
        self.var_to_val = self.shape_env.var_to_val
        # Accumulated ZeroGuard/PositiveGuard objects (currently not checked).
        self.guards = []
        # Symbol substitutions shared with the ShapeEnv; growing this dict
        # invalidates the caches built below.
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # maps of dynamic sizes that have to be precomputed on the host to the kernel args
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = dict()
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = dict()
        self.need_seed = False
        # Cached/memoized versions of the corresponding private methods.
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()
        # Codegen fragments; overridden by CppSizeVarAllocator for C++ output.
        self.declare = ""
        self.ending = ""
        self.as_strided = "as_strided"

    def seed(self):
        """
        Seed is a special variable used to hold the rng seed for a graph.

        Note this is only used by the CPU backend, we put seeds in a
        1-element tensor for the CUDA backend.
        """
        self.need_seed = True
        return sympy_symbol("seed")

    def simplify(self, expr: Expr):
        """Expand *expr* and apply all known symbol replacements."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache = dict()
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache = dict()
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (*index_vars, *sizes, *index_formulas)
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges):
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.
        """
        from .ir import FloorDiv, ModularIndexing

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if self.maybe_guard_leq(var_ranges[v], divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)
            if isinstance(base, ModularIndexing):
                # for modular indexing, biggest values from the ranges don't necessarily result in
                # the biggest result, the biggest result is modulus - 1
                base_s = base.args[2] - 1
            elif not base.has(ModularIndexing):
                # actual iteration range is to size-1
                iter_ranges_zero = {k: 0 for k, v in var_ranges.items()}
                base_lowest = sympy_subs(base, iter_ranges_zero)
                if self.maybe_guard_lt(base_lowest, 0):
                    # can't replace with indexing div if base can be negative
                    return ModularIndexing(base, divisor, modulus)
                iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
                base_s = sympy_subs(base, iter_ranges)
            else:
                base_s = base
            # if the maximum possible value stays below modulus*divisor, the
            # modulus never wraps and the expression is a plain FloorDiv
            if self.maybe_guard_lt(base_s, modulus * divisor):
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("modulus"),
                ),
                visit_modular_indexing,
            )
        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                ),
                visit_indexing_div,
            )
        # keep rewriting until a fixed point is reached
        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(self, index_vars, sizes, index_formulas):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop
        If channel_last = True, we will prevent the last dim fused with other dims
        """
        sizes = list(map(self.simplify, sizes))
        strides = [self.stride_vars(x, index_vars) for x in index_formulas]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))
        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            # True if dims a and b are contiguous in every index formula.
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    v = sympy_symbol("_merge_tester")
                    expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        # greedily merge dims until no further merge is possible
        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            # map an index over the pruned dims back to the original dims,
            # inserting 0 for removed dimensions
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            # drop entries corresponding to removed dimensions
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        """Assert (via the ShapeEnv) that left == right; returns left."""
        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
        return left

    def maybe_guard_equals(self, left: Expr, right: Expr) -> bool:
        """if left==right, guard on that fact and return true"""
        if left == right:
            return True
        if self.size_hint(left - right) == 0:
            self.guard_equals(left, right)
            return True
        return False

    def maybe_guard_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """if left==right, guard on that fact and return true"""
        if len(left) != len(right):
            return False
        if all(self.size_hint(a - b) == 0 for a, b in zip(left, right)):
            for a, b in zip(left, right):
                self.guard_equals(a, b)
            return True
        return False

    def maybe_guard_leq(self, left: Expr, right: Expr) -> bool:
        """If hints show left <= right, guard on that fact and return True."""
        try:
            if self.size_hint(left) > self.size_hint(right):
                return False
        except TypeError:
            # size_hint failed to produce an int (non-concrete expression)
            return False
        self.guard_leq(left, right)
        return True

    def maybe_guard_lt(self, left: Expr, right: Expr) -> bool:
        """If hints show left < right, guard on that fact and return True."""
        try:
            if self.size_hint(left) >= self.size_hint(right):
                return False
        except TypeError:
            # size_hint failed to produce an int (non-concrete expression)
            return False
        self.guard_lt(left, right)
        return True

    def guard_leq(self, left: Expr, right: Expr) -> None:
        # left <= right  is the same as  left < right + 1
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        expr = self.simplify(right - left)
        assert self.size_hint(expr) > 0
        if len(expr.free_symbols) == 0:
            # constant expression: already proven, no guard needed
            return
        if "-" in str(expr):
            # all vars are positive, so needs a minus sign to get negative values
            self.guards.append(PositiveGuard(expr))

    def guard_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        lv = self.size_hint(left)
        rv = self.size_hint(right)
        if lv == rv:
            return self.guard_equals(left, right)
        elif lv < rv:
            self.guard_lt(left, right)
            return left
        else:
            self.guard_lt(right, left)
            return right

    def guard_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # max(a, b) == -min(-a, -b)
        return -self.guard_min(-left, -right)

    def maybe_guard_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
        """if denominator divides numerator, return True and guard on that fact"""
        if sympy.gcd(numerator, denominator) == denominator:
            # can prove it symbolically
            return True
        if self.size_hint(numerator) % self.size_hint(denominator) == 0:
            self.guard_equals(numerator % denominator, 0)
            return True
        return False

    def guard_static_shape(self, left: Expr) -> int:
        """Pin *left* to its hinted concrete value and return it as an int."""
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def __getitem__(self, val: int) -> Expr:
        # delegate integer interning/ducking to the ShapeEnv
        return self.shape_env.duck_int(val)

    def size_hint(self, expr: Expr) -> int:
        """Evaluate *expr* using the example values of all symbols.

        Raises TypeError (via int()) if the result is not concrete.
        """
        out = sympy_subs(sympy.expand(expr), self.var_to_val)
        return int(out)

    def size_hints(self, exprs: List[Expr]) -> Tuple[int, ...]:
        """Apply size_hint to each expression; returns a tuple of ints."""
        return tuple(self.size_hint(x) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        """Build a cached version of _stride_vars (vars tuple-ized for hashing)."""
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(index: Expr, vars: List[sympy.Symbol]) -> List[Expr]:
            return cache(index, tuple(vars))

        return stride_vars

    def _stride_vars(self, index: Expr, vars: List[sympy.Symbol]) -> List[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element
        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    vars[j]: sympy.Integer(0)
                    for j in range(len(vars))
                    if i != j and vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.Integer(0))
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.Integer(1)})
                    - sympy_subs(index_dim, {v: sympy.Integer(0)})
                )
        return strides

    def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})

    def stride_hints(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        """Concrete (hinted) strides; 0 where no concrete value exists."""
        for v in index.free_symbols:
            if v.name.startswith("indirect"):
                # indirect-indexing symbols have no range info; zero them out
                index = sympy_subs(index, {v: 0})
        result = []
        for s in self.stride_vars(index, vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        """Rank dims by |stride|, placing zero-stride dims last."""
        strides = tuple(
            map(lambda x: abs(x), self.stride_hints(index, vars))
        )  # lambda to placate mypy
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr):
        """Intern *expr* as a host-precomputed size symbol (ps0, ps1, ...)."""
        if expr not in self.precomputed_replacements:
            sym = sympy_symbol(f"ps{len(self.precomputed_replacements)}")
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def codegen(self, code: IndentedBuffer, graph_inputs: Dict[str, ir.Buffer]):
        """Assign all symbolic shapes to locals"""
        if self.need_seed:
            code.writeline(
                "seed = torch.randint(2**31, size=(), dtype=torch.int32).item()"
            )

        @functools.lru_cache(None)
        def sizeof(name):
            # emit `<name>_size = <name>.size()` at most once per input
            code.writeline(f"{self.declare}{name}_size = {name}.size(){self.ending}")
            return f"{name}_size"

        @functools.lru_cache(None)
        def strideof(name):
            # emit `<name>_stride = <name>.stride()` at most once per input
            code.writeline(
                f"{self.declare}{name}_stride = {name}.stride(){self.ending}"
            )
            return f"{name}_stride"

        # Assign all symbolic shapes needed to local variables
        needed = set(self.var_to_val.keys()) - set(self.replacements.keys())
        for name, value in graph_inputs.items():
            shapes = value.get_size()
            for dim, shape in enumerate(shapes):
                shape = self.simplify(shape)
                if shape in needed:
                    needed.remove(shape)
                    code.writeline(
                        f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
                    )
        for name, value in graph_inputs.items():
            shapes = value.get_stride()
            for dim, shape in enumerate(shapes):
                shape = self.simplify(shape)
                if shape in needed:
                    needed.remove(shape)
                    code.writeline(
                        f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
                    )

    def codegen_precomputed_sizes(self, code: IndentedBuffer):
        """Emit assignments for all interned precomputed sizes."""
        from .codegen.wrapper import pexpr

        for sym, expr in self.inv_precomputed_replacements.items():
            code.writeline(f"{self.declare}{sym} = {pexpr(expr)}")

    def codegen_sizevar(self, x: Expr) -> str:
        """Render a single (simplified) size expression as Python source."""
        from .codegen.wrapper import pexpr

        return pexpr(self.simplify(x))

    def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        """Render a shape as a Python tuple literal, e.g. ``(s0, 2)``."""
        parts = list(map(self.codegen_sizevar, shape))
        if len(parts) == 0:
            return "()"
        if len(parts) == 1:
            # trailing comma keeps the 1-element tuple a tuple
            return f"({parts[0]}, )"
        return f"({', '.join(parts)})"

    def codegen_benchmark_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
        return self.codegen_shape_tuple(shape)
  425. def join_dimensions(expr: Expr) -> Expr:
  426. from .ir import ModularIndexing
  427. if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing):
  428. return expr # fast exit path
  429. return _join_dimensions_cached(expr)
@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
    becomes
    ModularIndexing(i0, 1, 128)
    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
    becomes i0

    This type of pattern can come from view operations
    """
    from .ir import FloorDiv, ModularIndexing

    assert isinstance(expr, sympy.Add)

    # wildcards for pattern matching; scale must be nonzero
    scale = sympy.Wild("scale", exclude=[0])
    base = sympy.Wild("base")
    divisor = sympy.Wild("divisor")
    mod1 = sympy.Wild("modulus")
    mod2 = sympy.Wild("modulus2")

    # Pass 1: merge two adjacent ModularIndexing terms into one with a
    # combined modulus; recurse on the rewritten sum and return immediately.
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale]
                    * m1[mod1]
                    * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2)
                )
                if m2 and term1 != term2:
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale]
                        * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2])
                    )
                    return expr
    # Pass 2: merge a ModularIndexing with the matching FloorDiv "upper
    # half" — together they reconstruct a plain FloorDiv of the base.
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1])
                )
                if m2 is not None:  # in case of success we get an empty dict here
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale] * FloorDiv(m1[base], m1[divisor])
                    )
                    return expr
    return expr
  481. class CppSizeVarAllocator(SizeVarAllocator):
  482. def __init__(self, shape_env=None):
  483. super().__init__(shape_env)
  484. self.declare = "auto "
  485. self.ending = ";"
  486. self.as_strided = "at::as_strided"
  487. def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
  488. parts = list(map(self.codegen_sizevar, shape))
  489. if len(parts) == 0:
  490. return "{}"
  491. if len(parts) == 1:
  492. return f"{{{parts[0]}, }}"
  493. return f"{{{', '.join(parts)}}}"
  494. def codegen_benchmark_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
  495. return super().codegen_shape_tuple(shape)
  496. class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined]
  497. """
  498. A wrapper around .virtualize.ops that uses var range information to
  499. simplify ir.ModularIndexing/ir.FloorDiv.
  500. """
  501. def __init__(self, inner, var_ranges: VarRanges):
  502. super().__init__(inner)
  503. self.name = "SimplifyIndexing"
  504. self._simplify: Callable[
  505. [Expr], Expr
  506. ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
  507. def load(self, name: str, index: sympy.Expr):
  508. return self._inner.load(name, self._simplify(index))
  509. def store(self, name, index, value, mode=None):
  510. return self._inner.store(name, self._simplify(index), value, mode=mode)
  511. def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
  512. return self._inner.reduction(
  513. name, dtype, src_dtype, reduction_type, self._simplify(index), value
  514. )
  515. def index_expr(self, index, dtype):
  516. return self._inner.index_expr(self._simplify(index), dtype)