symbolic_shapes.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498
  1. import torch
  2. from typing import Set, Dict, List, Type, Optional, cast, Union
  3. import sys
  4. import builtins
  5. import itertools
  6. import operator
  7. import math
  8. import functools
  9. import threading
  10. from contextlib import contextmanager
  11. from functools import lru_cache
  12. import traceback
  13. import collections
  14. import textwrap
  15. import logging
  16. # NB: The sym_* functions are used via getattr() and must be imported here.
  17. from torch import SymInt, SymFloat, SymBool, sym_not, sym_float, sym_max, sym_min # noqa: F401
  18. from torch._guards import ShapeGuard, Source
  19. SymTypes = (SymInt, SymFloat, SymBool)
  20. log = logging.getLogger(__name__)
  21. class GuardOnDataDependentSymNode(RuntimeError):
  22. pass
  23. try:
  24. import sympy # type: ignore[import]
  25. from sympy.printing.precedence import precedence # type: ignore[import] # noqa: F401
  26. from sympy.printing.str import StrPrinter # type: ignore[import]
  27. from sympy.core.logic import fuzzy_and, fuzzy_or # type: ignore[import]
  28. HAS_SYMPY = True
  29. except ImportError:
  30. HAS_SYMPY = False
  31. aten = torch._ops.ops.aten # type: ignore[has-type]
  32. __all__ = [
  33. "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
  34. "SymDispatchMode", "FloorDiv", "guard_int", "guard_float", "guard_scalar", "wrap_node",
  35. "method_to_operator", "hint_int", "SYMPY_INTERP",
  36. ]
  37. SYM_FUNCTION_MODE = None
  38. # We don't bother with the metaclass as all of the dispatching logic happens
  39. # entirely from Python
  40. #
  41. # Didn't bother with ancestors for now, unlikely to have multiple modes for
  42. # symints right now
  43. # SymDispatchMode gets invoked whenever an operation is processed on
  44. # a PySymInt. When this occurs, you get called at __sym_dispatch__
  45. # with the operation in question. This is symmetric to TorchDispatchMode
  46. # but with some caveats:
  47. #
  48. # - In TorchDispatchMode, you get the same arguments as what a user
  49. # invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
  50. # you get (a, b) as args to your call. In SymDispatchMode, if
  51. # you call a + b (where a and b are SymInts), you will get
  52. # (a.node, b.node) as your args (these are PySymInts)
  53. #
  54. # - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
  55. # So you have to manually call Tracer/create_node to write into
  56. # the graph. See ProxySymDispatchMode for an example
  57. #
  58. class SymDispatchMode:
  59. def __sym_dispatch__(self, func, types, args, kwargs):
  60. raise NotImplementedError()
  61. def __enter__(self):
  62. global SYM_FUNCTION_MODE
  63. old = SYM_FUNCTION_MODE
  64. if hasattr(self, "inner"):
  65. raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version")
  66. else:
  67. self.inner = old
  68. SYM_FUNCTION_MODE = self
  69. return self
  70. def __exit__(self, exc_type, exc_val, exc_tb):
  71. global SYM_FUNCTION_MODE
  72. SYM_FUNCTION_MODE = self.inner
  73. def has_symbolic_sizes_strides(elem):
  74. return elem._has_symbolic_sizes_strides
  75. def create_contiguous(shape):
  76. strides = [1]
  77. for dim in reversed(shape[:-1]):
  78. strides.append(dim * strides[-1])
  79. return list(reversed(strides))
  80. def _handle_sym_dispatch(func, args, kwargs):
  81. global SYM_FUNCTION_MODE
  82. mode = SYM_FUNCTION_MODE
  83. assert mode
  84. SYM_FUNCTION_MODE = mode.inner
  85. try:
  86. # TODO: properly compute types
  87. types: List[Type] = []
  88. return mode.__sym_dispatch__(func, types, args, kwargs)
  89. finally:
  90. SYM_FUNCTION_MODE = mode
  91. def hint_int(a):
  92. if isinstance(a, torch.SymInt):
  93. return a.node.require_hint()
  94. assert type(a) is int, a
  95. return a
  96. def guard_scalar(a):
  97. if isinstance(a, (SymBool, bool)):
  98. return guard_bool(a)
  99. elif isinstance(a, (SymInt, int)):
  100. return guard_int(a)
  101. elif isinstance(a, (SymFloat, float)):
  102. return guard_float(a)
  103. else:
  104. raise AssertionError(f"unrecognized scalar {a}")
  105. def guard_bool(a):
  106. if isinstance(a, SymBool):
  107. return a.node.guard_bool("", 0) # NB: uses Python backtrace
  108. assert type(a) is bool, a
  109. return a
  110. def guard_int(a):
  111. if isinstance(a, SymInt):
  112. return a.node.guard_int("", 0) # NB: uses Python backtrace
  113. assert type(a) is int, a
  114. return a
  115. def guard_float(a):
  116. if isinstance(a, SymFloat):
  117. return a.node.guard_float("", 0) # NB: uses Python backtrace
  118. assert isinstance(a, float), a
  119. return a
  120. # Drop in replacement for math.sqrt
  121. def sym_sqrt(a):
  122. if hasattr(a, '__sym_sqrt__'):
  123. return a.__sym_sqrt__()
  124. return math.sqrt(a)
  125. def to_node(self, num):
  126. if isinstance(num, SymTypes):
  127. return num.node
  128. elif type(num) is bool:
  129. return self.wrap_bool(num)
  130. elif type(num) is int:
  131. return self.wrap_int(num)
  132. elif type(num) is float:
  133. return self.wrap_float(num)
  134. else:
  135. # NotImplemented is important so that Python tries the
  136. # other magic method
  137. return NotImplemented
  138. # Given a GraphModule, return all the FakeTensors for all the placeholders
  139. def fx_placeholder_vals(gm):
  140. return [n.meta['val'] for n in gm.graph.nodes if n.op == "placeholder"]
  141. def fx_placeholder_targets(gm):
  142. return [n.target for n in gm.graph.nodes if n.op == "placeholder"]
  143. # Given a GraphModule and arguments to run it with, evaluate that the guards
  144. # for its associated ShapeEnv are satisfied by the passed arguments. This
  145. # WILL check for duck sizing.
  146. def eval_guards(gm, *args):
  147. return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args)
  148. def bind_symbols(gm, *args):
  149. return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)
  150. # TODO: An incomplete list
  151. # 1. Set variables to be equal when we do equality
  152. # 2. Specialize on 0/1 when we do subtraction
  153. class SymNode:
  154. """
  155. This is a type erased SymInt/SymFloat which we use to do actual operations.
  156. End users don't touch this. Magic methods are NOT defined on this object.
  157. """
  158. def __init__(self, expr, shape_env, pytype, hint: Optional[Union[int, float]], constant=None):
  159. self._expr = expr
  160. self.shape_env = shape_env
  161. self.pytype = pytype
  162. # What's the difference between hint and constant?
  163. #
  164. # - A constant is known to be invariant across invocations of the model;
  165. # it will always be this value. We only really know this when we
  166. # encounter an honest-to-goodness literal (when wrapping it into
  167. # a SymNode, we set constant.) Most of the time, constant is None
  168. #
  169. # - A hint is a *particular* value from the particular run we are
  170. # tracing, but it may vary the next time around. It's useful to
  171. # keep this around, as if we need a concrete value from a SymNode,
  172. # we will return the hint and guard on the expression that produced
  173. # it giving the same hint next time around. The hint is not
  174. # guaranteed to be set either: if you have an unbacked SymNode,
  175. # there won't be any hint; it was the result of some tensor-dependent
  176. # computation, but we don't know what it actually is because we
  177. # haven't actually run the tensor computation.
  178. #
  179. # hint_expr is only set if we don't have a hint. When it is set, it
  180. # contains the expression which contains the unbacked symnodes that,
  181. # if constrained, would allow this expression to be hinted again.
  182. if hint is None:
  183. self._hint_expr = self.expr.xreplace(shape_env.var_to_val)
  184. self._hint = None
  185. self._update_hint() # check if the replacement actually was enough
  186. else:
  187. self._hint_expr = None
  188. self._hint = hint
  189. self.constant: Optional[Union[int, float, bool]] = constant
  190. @property
  191. def expr(self):
  192. self._update_expr()
  193. return self._expr
  194. # Check if we have replacements hint_expr that would allow us to
  195. # simplify it into a hint
  196. def _update_hint(self):
  197. if self._hint_expr.free_symbols <= self.shape_env.replacements.keys():
  198. self._hint = self.pytype(self.shape_env.replace(self._hint_expr))
  199. self._hint_expr = None
  200. @property
  201. def hint(self):
  202. if self._hint is None:
  203. self._update_hint()
  204. return self._hint
  205. def require_hint(self):
  206. if self._hint is None:
  207. self._update_hint()
  208. if self._hint is None:
  209. raise self.shape_env._make_data_dependent_error(self._hint_expr)
  210. else:
  211. return self._hint
  212. else:
  213. return self._hint
  214. def _update_expr(self):
  215. self._expr = self.shape_env.replace(self._expr)
  216. def is_int(self):
  217. return self.pytype is int
  218. def is_float(self):
  219. return self.pytype is float
  220. def is_bool(self):
  221. return self.pytype is bool
  222. def wrap_int(self, num):
  223. assert type(num) is int
  224. return SymNode(sympy.Integer(num), self.shape_env, int, num, constant=num)
  225. def wrap_float(self, num):
  226. assert type(num) is float
  227. return SymNode(sympy.Float(num), self.shape_env, float, num, constant=num)
  228. def wrap_bool(self, num):
  229. assert type(num) is bool
  230. return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num)
  231. def clone(self):
  232. return self
  233. def str(self):
  234. return f"{self.expr}"
  235. def __str__(self):
  236. return self.str()
  237. def __repr__(self):
  238. return self.str()
  239. # These methods call the metaprogrammed methods, they're hand written
  240. # here so we get good stack traces
  241. def add(self, other) -> "SymNode": # noqa: F811
  242. return self._add(other) # type: ignore[attr-defined]
  243. def sub(self, other) -> "SymNode": # noqa: F811
  244. return self._sub(other) # type: ignore[attr-defined]
  245. def mul(self, other) -> "SymNode": # noqa: F811
  246. return self._mul(other) # type: ignore[attr-defined]
  247. def mod(self, other) -> "SymNode": # noqa: F811
  248. return self._mod(other) # type: ignore[attr-defined]
  249. def pow(self, other) -> "SymNode": # noqa: F811
  250. return self._pow(other) # type: ignore[attr-defined]
  251. def and_(self, other) -> "SymNode": # noqa: F811
  252. return self._and_(other) # type: ignore[attr-defined]
  253. def or_(self, other) -> "SymNode": # noqa: F811
  254. return self._or_(other) # type: ignore[attr-defined]
  255. def truediv(self, other) -> "SymNode": # noqa: F811
  256. return self._truediv(other) # type: ignore[attr-defined]
  257. def floordiv(self, other) -> "SymNode": # noqa: F811
  258. return self._floordiv(other) # type: ignore[attr-defined]
  259. def sym_not(self) -> "SymNode": # noqa: F811
  260. return self._sym_not() # type: ignore[attr-defined]
  261. def eq(self, other) -> "SymNode": # noqa: F811
  262. return self._eq(other) # type: ignore[attr-defined]
  263. def ne(self, other) -> "SymNode": # noqa: F811
  264. return self._ne(other) # type: ignore[attr-defined]
  265. def gt(self, other) -> "SymNode": # noqa: F811
  266. return self._gt(other) # type: ignore[attr-defined]
  267. def lt(self, other) -> "SymNode": # noqa: F811
  268. return self._lt(other) # type: ignore[attr-defined]
  269. def le(self, other) -> "SymNode": # noqa: F811
  270. return self._le(other) # type: ignore[attr-defined]
  271. def ge(self, other) -> "SymNode": # noqa: F811
  272. return self._ge(other) # type: ignore[attr-defined]
  273. def floor(self) -> "SymNode": # noqa: F811
  274. return self._floor() # type: ignore[attr-defined]
  275. def sym_float(self) -> "SymNode": # noqa: F811
  276. return self._sym_float() # type: ignore[attr-defined]
  277. def ceil(self) -> "SymNode": # noqa: F811
  278. return self._ceil() # type: ignore[attr-defined]
  279. def neg(self) -> "SymNode": # noqa: F811
  280. return self._neg() # type: ignore[attr-defined]
  281. def sym_min(self, other) -> "SymNode": # noqa: F811
  282. return self._sym_min(other) # type: ignore[attr-defined]
  283. def sym_max(self, other) -> "SymNode": # noqa: F811
  284. return self._sym_max(other) # type: ignore[attr-defined]
  285. def sym_sqrt(self) -> "SymNode": # noqa: F811
  286. return self._sym_sqrt() # type: ignore[attr-defined]
  287. def is_non_overlapping_and_dense_indicator(self, *args) -> "SymNode": # noqa: F811
  288. return self._is_non_overlapping_and_dense_indicator(*args) # type: ignore[attr-defined]
  289. # Make C++ happy
  290. def sym_or(self, other): # noqa: F811
  291. return self.or_(other)
  292. def sym_and(self, other): # noqa: F811
  293. return self.and_(other)
  294. # Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
  295. def int_(self):
  296. if len(self.expr.free_symbols) == 0:
  297. return int(self.expr)
  298. raise RuntimeError(f"Trying to extract a concrete int out of a symbolic int {self.expr}")
  299. # You can manually trigger a guard with this function
  300. def guard_int(self, file, line):
  301. # TODO: use the file/line for some useful diagnostic on why a
  302. # guard occurred
  303. r = self.shape_env.evaluate_expr(self.expr, self.hint)
  304. try:
  305. return int(r)
  306. except Exception:
  307. log.warning(f"Failed to convert to int: {r}")
  308. raise
  309. def guard_float(self, file, line):
  310. # TODO: use the file/line for some useful diagnostic on why a
  311. # guard occurred
  312. r = self.shape_env.evaluate_expr(self.expr, self.hint)
  313. try:
  314. return float(r)
  315. except Exception:
  316. log.warning(f"Failed to convert to float: {r}")
  317. raise
  318. def guard_bool(self, file, line):
  319. # TODO: use the file/line for some useful diagnostic on why a
  320. # guard occurred
  321. r = self.shape_env.evaluate_expr(self.expr, self.hint)
  322. try:
  323. return bool(r)
  324. except Exception:
  325. log.warning(f"Failed to convert to bool: {r}")
  326. raise
  327. def bool_(self):
  328. return self.guard_bool("", 0)
  329. if HAS_SYMPY:
  330. # Overloaded to be compatible with regular Python.
  331. # https://github.com/pytorch/pytorch/issues/90900
  332. class Pow(sympy.Function):
  333. @classmethod
  334. def eval(cls, base, exp):
  335. if exp.is_zero:
  336. return sympy.Integer(1)
  337. elif base.is_zero and exp < 0:
  338. raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
  339. else:
  340. return base ** exp
  341. # Overloaded to be compatible with regular Python.
  342. # https://github.com/pytorch/pytorch/issues/90900
  343. class TrueDiv(sympy.Function):
  344. @classmethod
  345. def eval(cls, base, divisor):
  346. if divisor.is_zero:
  347. raise ZeroDivisionError("division by zero")
  348. else:
  349. return base / divisor
  350. class FloorDiv(sympy.Function):
  351. """
  352. We maintain this so that:
  353. 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
  354. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
  355. """
  356. nargs = (2,)
  357. precedence = 50 # precedence of mul # noqa: F811
  358. # Default return type for SymPy assumptions.
  359. # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
  360. is_real = True
  361. @property
  362. def base(self):
  363. return self.args[0]
  364. @property
  365. def divisor(self):
  366. return self.args[1]
  367. def _sympystr(self, printer):
  368. base = printer.parenthesize(self.base, self.precedence)
  369. divisor = printer.parenthesize(self.divisor, self.precedence)
  370. return f"{base}//{divisor}"
  371. # SymPy assumptions based on argument types.
  372. def _eval_is_real(self):
  373. return fuzzy_or([self.base.is_real, self.divisor.is_real])
  374. def _eval_is_integer(self):
  375. return fuzzy_and([self.base.is_integer, self.divisor.is_integer])
  376. # Automatic evaluation.
  377. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
  378. @classmethod
  379. def eval(cls, base, divisor):
  380. def check_supported_type(x):
  381. if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
  382. raise TypeError(
  383. f"unsupported operand type(s) for //: "
  384. f"'{type(base).__name__}' and '{type(divisor).__name__}'"
  385. f", expected integer or real")
  386. check_supported_type(base)
  387. check_supported_type(divisor)
  388. # We don't provide the same error message as in Python because SymPy
  389. # makes it difficult to check the types.
  390. if divisor.is_zero:
  391. raise ZeroDivisionError("division by zero")
  392. if base.is_zero:
  393. return sympy.S.Zero
  394. if base.is_integer and divisor == 1:
  395. return base
  396. if base.is_real and divisor == 1:
  397. return sympy.floor(base)
  398. if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
  399. return base // divisor
  400. if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)):
  401. return sympy.floor(base / divisor)
  402. if isinstance(base, FloorDiv):
  403. return FloorDiv(base.args[0], base.args[1] * divisor)
  404. if isinstance(base, sympy.Add):
  405. for a in base.args:
  406. gcd = sympy.gcd(a, divisor)
  407. if gcd == divisor:
  408. return FloorDiv(base - a, divisor) + a / gcd
  409. gcd = sympy.gcd(base, divisor)
  410. if gcd != 1:
  411. return FloorDiv(
  412. sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
  413. )
  414. class IsNonOverlappingAndDenseIndicator(sympy.Function):
  415. is_integer = True
  416. @classmethod
  417. def eval(cls, *args):
  418. assert len(args) % 2 == 0
  419. if all(isinstance(a, sympy.Integer) for a in args):
  420. dim = len(args) // 2
  421. sizes = args[0:dim]
  422. strides = args[dim:]
  423. return int(eval_is_non_overlapping_and_dense(
  424. [int(s) for s in sizes],
  425. [int(s) for s in strides]
  426. ))
  427. return None
  428. @lru_cache(256)
  429. def safe_expand(r):
  430. if hasattr(r, 'expand'):
  431. try:
  432. return sympy.expand(r)
  433. except RecursionError:
  434. log.warning(f"RecursionError in sympy.expand({r})")
  435. return r
  436. else:
  437. return r
  438. # Methods that have a `__foo__` as well as `__rfoo__`
  439. reflectable_magic_methods = {
  440. 'add': lambda a, b: a + b,
  441. 'sub': lambda a, b: a - b,
  442. 'mul': lambda a, b: a * b,
  443. 'mod': lambda a, b: a % b,
  444. 'pow': lambda a, b: Pow(a, b),
  445. 'and': lambda a, b: a & b,
  446. 'or': lambda a, b: a | b,
  447. 'truediv': lambda a, b: TrueDiv(a, b),
  448. 'floordiv': lambda a, b: FloorDiv(a, b),
  449. }
  450. def error():
  451. raise AssertionError("shouldn't be hit")
  452. def floor_ceil_helper(a, fn):
  453. if isinstance(a, sympy.Mul):
  454. aa = a.args
  455. if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
  456. coef = sympy.Integer(aa[0])
  457. if aa[0] == coef: # structural equality test
  458. return coef * aa[1]
  459. if isinstance(a, sympy.Float) and a == sympy.Integer(a) or isinstance(a, sympy.Integer):
  460. return sympy.Integer(a)
  461. return fn(a)
  462. def floor_impl(a):
  463. return floor_ceil_helper(a, sympy.floor)
  464. def ceil_impl(a):
  465. return floor_ceil_helper(a, sympy.ceiling)
  466. magic_methods = {
  467. **reflectable_magic_methods,
  468. 'sym_not': lambda a: ~a,
  469. 'eq': lambda a, b: sympy.Eq(a, b),
  470. 'ne': lambda a, b: sympy.Ne(a, b),
  471. 'gt': lambda a, b: sympy.Gt(a, b),
  472. 'lt': lambda a, b: sympy.Lt(a, b),
  473. 'le': lambda a, b: sympy.Le(a, b),
  474. 'ge': lambda a, b: sympy.Ge(a, b),
  475. 'floor': floor_impl,
  476. 'sym_float': lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals
  477. 'ceil': ceil_impl,
  478. 'neg': lambda a: -a,
  479. 'sym_min': lambda a, b: sympy.Min(a, b),
  480. 'sym_max': lambda a, b: sympy.Max(a, b),
  481. 'sym_sqrt': lambda a: sympy.sqrt(a),
  482. }
  483. sizes_strides_methods = {
  484. 'is_non_overlapping_and_dense': lambda *args: IsNonOverlappingAndDenseIndicator(*args),
  485. }
  486. alternate_impl_if_hinted_methods = {
  487. "sym_min": builtins.min,
  488. "sym_max": builtins.max,
  489. }
  490. # TODO: Deduplicate this with torch/_prims_common/__init__.py
  491. def eval_is_non_overlapping_and_dense(sizes, strides):
  492. dim = len(sizes)
  493. # Short-circuits for tensors of rank one, which are
  494. # non-overlapping and "dense" if their stride is one
  495. # or it is a 0/1 element tensor
  496. if dim == 1:
  497. return strides[0] == 1 or sizes[0] < 2
  498. # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
  499. # Sorts (length, stride) pairs by stride
  500. lengths_and_strides = sorted(
  501. zip(sizes, strides), key=operator.itemgetter(1)
  502. )
  503. # Unlike the C++ code, we don't move the 0/1 size dimensions to the
  504. # end. So we have to keep going for this code.
  505. expected_stride = 1
  506. for length, stride in lengths_and_strides:
  507. if length == 1:
  508. continue
  509. if stride != expected_stride:
  510. return False
  511. expected_stride *= length
  512. return True
  513. def is_non_overlapping_and_dense(sizes, strides):
  514. base = None
  515. for s in itertools.chain(sizes, strides):
  516. if isinstance(s, SymInt):
  517. base = s
  518. break
  519. assert base is not None
  520. return wrap_node(base.node.is_non_overlapping_and_dense(
  521. [to_node(base.node, s) for s in sizes],
  522. [to_node(base.node, s) for s in strides],
  523. ))
  524. unary_magic_methods = {
  525. 'sym_float',
  526. 'ceil',
  527. 'floor',
  528. 'neg',
  529. 'sym_sqrt',
  530. 'sym_not',
  531. }
  532. bool_magic_methods = {"and", "or", "sym_not"}
  533. magic_methods_on_math = {"ceil", "floor"}
  534. magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not"}
  535. magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
  536. def method_to_operator(method):
  537. if method in magic_methods_on_operator_with_trailing_underscore:
  538. method_attr = f"{method}_"
  539. else:
  540. method_attr = method
  541. if method in magic_methods_on_submodule:
  542. op = getattr(torch.fx.experimental.symbolic_shapes, method_attr)
  543. elif method in magic_methods_on_math:
  544. op = getattr(math, method_attr)
  545. else:
  546. op = getattr(operator, method_attr)
  547. return op
  548. SYMPY_INTERP = {
  549. 'Eq': operator.eq,
  550. 'Ne': operator.ne,
  551. 'Gt': operator.gt,
  552. 'Lt': operator.lt,
  553. 'Le': operator.le,
  554. 'Ge': operator.ge,
  555. 'Min': min,
  556. 'Max': max,
  557. 'Mod': operator.mod,
  558. 'FloorDiv': operator.floordiv,
  559. 'TrueDiv': operator.truediv,
  560. 'floor': math.floor,
  561. 'ceiling': math.ceil,
  562. }
  563. always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"}
  564. always_int_magic_methods = {"ceil", "floor"}
  565. always_bool_magic_methods = {"eq", "ne", "gt", "lt", "le", "ge", "and", "or", "sym_not", "is_non_overlapping_and_dense"}
  566. def wrap_node(x):
  567. # TODO: let C++ also take advantage of this
  568. if isinstance(x, SymNode) and x.constant is not None:
  569. return x.constant
  570. if x.is_int():
  571. return SymInt(x)
  572. elif x.is_float():
  573. return SymFloat(x)
  574. elif x.is_bool():
  575. return SymBool(x)
  576. else:
  577. raise AssertionError(f"unrecognized return type {x}")
  578. def _make_node_magic(method, func):
  579. func = lru_cache(256)(func)
  580. if method in magic_methods_on_operator_with_trailing_underscore:
  581. method_attr = f"{method}_"
  582. else:
  583. method_attr = method
  584. def binary_magic_impl(self, other):
  585. op = method_to_operator(method)
  586. out_hint = None
  587. if self.hint is not None and other.hint is not None:
  588. out_hint = op(self.hint, other.hint)
  589. alternate_impl = alternate_impl_if_hinted_methods.get(method)
  590. if alternate_impl and out_hint is not None:
  591. return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
  592. if SYM_FUNCTION_MODE:
  593. return to_node(self, _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}))
  594. assert isinstance(other, SymNode)
  595. other_expr = other.expr
  596. # TODO: consider constant prop here
  597. expr = self.shape_env.replace(self.expr)
  598. other_expr = self.shape_env.replace(other_expr)
  599. try:
  600. out = func(expr, other_expr)
  601. except Exception:
  602. log.warning(f"failed to eval {method}({expr}, {other_expr})")
  603. raise
  604. out = safe_expand(out)
  605. pytype: Type
  606. # This is not strictly correct. In Python, a**b may return complex when
  607. # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
  608. # returns a float while both arguments are ints: 2**(-1). Also, max and
  609. # min do not type promote. To avoid having data-dependent control flow
  610. # here, we just set the type to float if one of the args is a float. In
  611. # case of a type mismatch, we assume that it will be detected during
  612. # evaluation.
  613. if method in always_float_magic_methods:
  614. pytype = float
  615. elif method in always_bool_magic_methods:
  616. pytype = bool
  617. elif self.pytype is float or other.pytype is float:
  618. pytype = float
  619. else:
  620. pytype = self.pytype
  621. return SymNode(out, self.shape_env, pytype, out_hint)
  622. def unary_magic_impl(self):
  623. op = method_to_operator(method)
  624. if SYM_FUNCTION_MODE:
  625. return to_node(self, _handle_sym_dispatch(op, (wrap_node(self),), {}))
  626. # TODO: consider constant prop here
  627. expr = self.shape_env.replace(self.expr)
  628. try:
  629. out = func(expr)
  630. except Exception:
  631. log.warning(f"failed to eval {method}({expr})")
  632. raise
  633. out_hint = None
  634. if self.hint is not None:
  635. out_hint = op(self.hint)
  636. out = safe_expand(out)
  637. pytype: Type
  638. if method in always_int_magic_methods:
  639. pytype = int
  640. elif method in always_float_magic_methods:
  641. pytype = float
  642. else:
  643. pytype = self.pytype
  644. return SymNode(out, self.shape_env, pytype, out_hint)
  645. if method in unary_magic_methods:
  646. setattr(SymNode, f"_{method_attr}", unary_magic_impl)
  647. else:
  648. setattr(SymNode, f"_{method_attr}", binary_magic_impl)
  649. def _make_node_sizes_strides(method, func):
  650. # NB: don't LRU cache, lots of arguments
  651. def sizes_strides_impl(self, sizes, strides):
  652. op = getattr(sys.modules[__name__], method)
  653. if SYM_FUNCTION_MODE:
  654. r = _handle_sym_dispatch(op, ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), {})
  655. assert isinstance(r, SymBool), type(r)
  656. return r.node
  657. size_exprs = [s.expr for s in sizes]
  658. stride_exprs = [s.expr for s in strides]
  659. try:
  660. out = func(*size_exprs, *stride_exprs)
  661. except Exception:
  662. log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})")
  663. raise
  664. hints = []
  665. out_hint = None
  666. for s in itertools.chain(sizes, strides):
  667. if s.hint is None:
  668. break
  669. hints.append(s.hint)
  670. else:
  671. out_hint = op(*hints)
  672. # bool is never expandable
  673. return SymNode(sympy.Eq(out, 1), self.shape_env, bool, out_hint)
  674. setattr(SymNode, f"_{method}", sizes_strides_impl)
# Install the internal SymNode implementations for every magic method and
# every sizes/strides predicate.
# NB: ``method`` and ``func`` remain bound at module level after these loops;
# they are ``del``-ed after the user-facing registration loop further below.
for method, func in magic_methods.items():
    _make_node_magic(method, func)
for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)
  679. def _make_user_magic(method, user_type):
  680. # User magic takes care of wrapping the other operand into a node,
  681. # so that our internal logic can assume everything is nodes
  682. if method in magic_methods_on_operator_with_trailing_underscore:
  683. method_attr = f"{method}_"
  684. else:
  685. method_attr = method
  686. def unary_magic_impl(self):
  687. return wrap_node(getattr(self.node, method_attr)())
  688. def binary_magic_impl(self, other):
  689. other_node = to_node(self.node, other)
  690. if other_node is NotImplemented:
  691. return NotImplemented
  692. return wrap_node(getattr(self.node, method_attr)(other_node))
  693. def rbinary_magic_impl(self, other):
  694. other_node = to_node(self.node, other)
  695. if other_node is NotImplemented:
  696. return NotImplemented
  697. return wrap_node(getattr(other_node, method_attr)(self.node))
  698. if method in unary_magic_methods:
  699. setattr(user_type, f"__{method}__", unary_magic_impl)
  700. else:
  701. setattr(user_type, f"__{method}__", binary_magic_impl)
  702. if method in reflectable_magic_methods:
  703. setattr(user_type, f"__r{method}__", rbinary_magic_impl)
  704. for method, func in magic_methods.items():
  705. if method in bool_magic_methods:
  706. _make_user_magic(method, SymBool)
  707. else:
  708. _make_user_magic(method, SymInt)
  709. _make_user_magic(method, SymFloat)
  710. del method
  711. del func
  712. def _lru_cache(fn, maxsize=None):
  713. """
  714. Wrapper around lru_cache that clears when new info about shapes has been
  715. updated.
  716. Use lru_cache if the output is always the same, regardless of the
  717. constraints we know now (i.e. evaluate_expr)
  718. Use _lru_cache otherwise.
  719. """
  720. fn_cache = lru_cache(maxsize)(fn)
  721. prior_key = None
  722. @functools.wraps(fn)
  723. def wrapper(self, *args, **kwargs):
  724. nonlocal prior_key
  725. if prior_key != self._get_key():
  726. prior_key = self._get_key()
  727. fn_cache.cache_clear()
  728. return fn_cache(self, *args, **kwargs)
  729. wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined]
  730. return wrapper
  731. if HAS_SYMPY:
  732. # This stub exists so we can easily add metadata to sympy symbols
  733. # NB: This inherits from Dummy, not Symbol, because Symbols with the same
  734. # name get interned. This is bad for us as we want the metadata
  735. # to vary across different invocations and not leak.
  736. class Symbol(sympy.Dummy):
  737. __slots__: List[str] = ['sources', 'stack']
  738. sources: List[Source]
  739. stack: Optional[str]
  740. def __new__(cls, *args, **kwargs):
  741. self = super().__new__(cls, *args, **kwargs)
  742. self.sources = []
  743. self.stack = None
  744. return self
  745. class ShapeGuardPrinter(StrPrinter):
  746. def __init__(
  747. self,
  748. symbol_to_source,
  749. source_ref,
  750. ):
  751. super().__init__()
  752. self.symbol_to_source = symbol_to_source
  753. self.source_ref = source_ref
  754. def _print_Symbol(self, expr) -> str:
  755. assert isinstance(expr, Symbol), str(type(expr))
  756. assert expr in self.symbol_to_source, (
  757. f"{expr} (could be from {[s.name() for s in expr.sources]}) "
  758. f"not in {self.symbol_to_source}"
  759. )
  760. return self.source_ref(self.symbol_to_source[expr][0])
# Thread-local state. ShapeEnv stores its ``suppress_guards`` flag here, so
# guard suppression applies only to the thread that requested it.
TLS = threading.local()
  762. class ShapeEnv:
  763. def __init__(self):
  764. self.guards: List[ShapeGuard] = []
  765. # Maps symbolic ints to their original concrete values
  766. # Currently populated from tensors
  767. self.var_to_val: Dict["sympy.Symbol", "sympy.Integer"] = {}
  768. # Maps from sympy ints to expressions representing them
  769. # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
  770. self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {} #
  771. # Set holds a % b expressions that evaluate to 0.
  772. self.divisible: Set["sympy.Expr"] = set()
  773. # Duck-shaping says that if two input tensors have the same size,
  774. # they get assigned the same symbolic variable
  775. self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)}
  776. self.unbacked_symfloat_counter = itertools.count()
  777. self.unbacked_symint_counter = itertools.count()
  778. def _suppress_guards_tls(self):
  779. return getattr(TLS, "suppress_guards", False)
  780. @contextmanager
  781. def suppress_guards(self):
  782. TLS.suppress_guards = True
  783. try:
  784. yield
  785. finally:
  786. TLS.suppress_guards = False
  787. def _get_key(self):
  788. """
  789. Defines the current "state" of the guards we've accumulated in this ShapeEnv.
  790. Determines when we need to invalidate our cache
  791. """
  792. return (len(self.replacements), len(self.divisible))
    def create_symbolic_sizes_strides_storage_offset(self, ex: torch.Tensor, source: Source):
        """
        Returns a 3-tuple for the given tensor: a list of symbolic sizes, a
        list of symbolic strides, and a symbolic storage offset.
        We try our best to express stride in terms of the sizes, so as to not
        introduce new symbolic variables.
        """
        from torch._dynamo.source import TensorPropertySource, TensorProperty
        # One (possibly duck-sized) expression per dimension's size.
        size = [
            self.create_symbol(
                val, TensorPropertySource(source, TensorProperty.SIZE, i)
            ) for i, val in enumerate(ex.size())
        ]
        stride: List[Optional[sympy.Expr]] = [None] * len(size)
        # Strides of 0 and 1 are specialized to constants up front.
        for i, val in enumerate(ex.stride()):
            if val in (0, 1):
                stride[i] = sympy.Integer(val)
        # Every iteration binds at least one stride (by matching a candidate,
        # or by allocating a fresh symbol below), so this loop terminates.
        while any(x is None for x in stride):
            # For each bound, non-negative stride, size[i] * stride[i] is a
            # concrete value another stride may equal (as in contiguous
            # layouts, where an outer stride is inner size * inner stride).
            candidates = {
                ex.size(i) * ex.stride()[i]: size[i] * stride[i]
                for i in range(len(size))
                if stride[i] is not None and ex.stride()[i] >= 0
            }
            # iterate over unbound strides in sorted order
            val_list = sorted(
                [(ex.stride()[i], i) for i in range(len(stride)) if stride[i] is None]
            )
            for _, i in val_list:
                if stride[i] is None and ex.stride()[i] in candidates:
                    stride[i] = candidates[ex.stride()[i]]
                    # A newly bound stride can seed later matches in this pass.
                    candidates[ex.size(i) * ex.stride()[i]] = size[i] * stride[i]
            if any(x is None for x in stride):
                # bind the smallest unbound stride to a new variable
                val, i = min(
                    [
                        (ex.stride()[i], i)
                        for i in range(len(stride))
                        if stride[i] is None
                    ]
                )
                stride[i] = self.create_symbol(
                    val,
                    TensorPropertySource(source, TensorProperty.STRIDE, i)
                )
        assert all(x is not None for x in stride)
        sym_size = [self.create_symintnode(i, hint=hint) for i, hint in zip(size, ex.size())]
        sym_stride = []
        for i, stride_expr in enumerate(stride):
            # NB: Don't duck size the stride; instead use the expression
            # we computed
            assert stride_expr is not None
            sym_stride.append(self.create_symintnode(stride_expr, hint=ex.stride(i)))
        # The storage offset always gets its own create_symbol call (subject
        # to duck sizing like any other value).
        sym_storage_offset = self.create_symintnode(self.create_symbol(
            ex.storage_offset(),
            TensorPropertySource(source, TensorProperty.STORAGE_OFFSET)
        ), hint=ex.storage_offset())
        return sym_size, sym_stride, sym_storage_offset
  849. # If you know what the current hint value of the SymInt to be created
  850. # is, pass it into hint. Otherwise, pass None and we will make our best
  851. # guess
  852. def create_symintnode(self, sym: "sympy.Expr", *, hint: Optional[int]):
  853. return SymInt(SymNode(sym, self, int, hint))
  854. def create_unbacked_symfloat(self):
  855. symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
  856. symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
  857. return SymFloat(SymNode(symbol, self, float, None))
  858. def create_unbacked_symint(self):
  859. symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
  860. symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
  861. return SymInt(SymNode(symbol, self, int, None))
  862. # This is guaranteed to return a symbol or its negation is a sympy.Symbol,
  863. # but there may be a replacement that allows it to be immediately
  864. # simplified
  865. def create_symbol(self, val: int, source: Source) -> "sympy.Expr":
  866. assert isinstance(source, Source), f"{type(source)} {source}"
  867. if not HAS_SYMPY:
  868. raise RuntimeError("Need sympy installed to create symbolic shapes")
  869. if val < 0:
  870. from torch._dynamo.source import NegateSource
  871. return -self.create_symbol(-val, NegateSource(source))
  872. # Now attempt to duck size this value
  873. # TODO: Use site has to duck size
  874. # TODO: Do this duck sizing lazily later
  875. # Create a duck sized int if necessary
  876. if val not in self.val_to_var:
  877. sympy_expr = Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
  878. self.var_to_val[sympy_expr] = sympy.Integer(val)
  879. self.val_to_var[val] = sympy_expr
  880. # This implements duck-shaping: input sizes that match are assigned
  881. # the same symint
  882. r = self.duck_int(val)
  883. if isinstance(r, Symbol):
  884. r.sources.append(source)
  885. return r
  886. # Given a concrete integer value, return the duck sized symbol associated
  887. # with it; e.g., suppose we already have a tensor of size 3 in scope,
  888. # which was assigned s3, then shape_env.duck_int(3) we will get back s3.
  889. # This has some pretty tricky preconditions associated with it, so if
  890. # you are in a binding context, you probably wanted create_symbol instead.
  891. def duck_int(self, val):
  892. assert val in self.val_to_var, (
  893. "Direct call to duck_int MUST only duck size an integer values "
  894. "that have already produced by inputs (allocated "
  895. "by create_symbol), or we risk being unable to instantiate the "
  896. "symbolic variable later. However, at time of this call "
  897. f"val={val} was not duck sized. Bound duck sized integers: "
  898. f"{list(self.val_to_var.keys())}"
  899. )
  900. return self.val_to_var[val]
    def produce_guards(self, placeholders, sources,
                       source_ref=lambda n: n.name(), *, _simplified=False) -> List[str]:
        """
        Generate a list of guard strings which, when evaluated in a context
        that defines tensors for all the sources, each return True or False.
        The guards hold iff every string evaluates to True
        (evaluate_guards_for_args joins them with " and "). Primarily used by
        Dynamo, but also helpful for manual testing of guards.

        Args:
            placeholders: inputs paired with ``sources``; each is None
                (skipped), a SymInt, or a torch.Tensor.
            sources: one Source per placeholder. For convenience in testing, a
                source may also be a str, which is wrapped in a LocalSource.
            source_ref: renders a Source as the Python expression used in the
                guard text (defaults to ``source.name()``).
            _simplified: omit the input-equality guards (step 1) and the 0/1
                guards (step 3) below. This is useful for testing when you
                don't care about the boilerplate guards, and it may be helpful
                for user output too (be careful though; some equality guards
                are nontrivial! It would be nice to get simplified output to
                print them too). It's private because it's not intended for
                normal use.

        Returns:
            The list of guard expression strings.
        """
        # It took a lot of sweat to figure out the algorithm here. Let's
        # explain how it works.
        #
        # The ShapeEnv lifecycle looks something like this:
        #
        # - For each input, you either generate a fresh Sympy symbol (s0) to
        #   represent its value (a binding site), or you reuse some
        #   preexisting symbol or expression, skipping the symbol allocation
        #   (e.g., duck sizing to a preexisting symbol, or expressing a
        #   stride as a multiplication of a separate stride and size.)
        #   Naively, you might expect to bind a fresh Sympy symbol for
        #   every input, but this is fairly wasteful as most of these
        #   symbols immediately simplify away, and if you don't eagerly
        #   specialize, e.g., 0/1 symbols, you end up with very complicated
        #   expressions that are not optimizable in practice.
        #
        # - You perform some compute on these symbols, occasionally
        #   introducing guards on boolean expressions on these symbols.
        #   In particular, whenever we guard on equality (_maybe_guard_eq),
        #   we can simplify shapes; e.g., when s0 == s1 * 2, we can now
        #   replace all occurrences of s0 with s1 * 2. Sometimes, a
        #   boolean expression evaluation doesn't introduce a guard, as
        #   the guard is already entailed by the simplifications we have
        #   applied.
        #
        # - In the end, you have a bunch of replacements (saying how to
        #   simplify shapes) and a bunch of guards (all the equality guards
        #   are trivial, because they're covered by the replacements).
        #
        # From the ShapeEnv, we must generate a Python expression that, when
        # evaluated on a set of inputs, tells us whether or not these boolean
        # expressions would have evaluated in the same way. However,
        # we cannot easily compute this, as we elide recording boolean
        # expressions when we think they are vacuously true. Thus, we seek
        # an approximation: we must generate an expression that, if true,
        # would have produced an "equivalent" ShapeEnv, which would answer
        # guard expressions in the same way.
        #
        # Our notion of equivalence is a bit subtle. For example, consider
        # the ShapeEnv created from an input of size (5, 4) versus (4, 4)
        # (no other guards.) Duck sizing would generate (s0, s1) in the first
        # case but (s0, s0) in the second. We do NOT assume that size
        # variables are disjoint; so in fact a graph that assumes the input
        # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not
        # vice versa. However, consider an analogous case (1,) versus (2,).
        # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT
        # subsume the (1,) graph because we assume that any size variable
        # is NOT 0/1 (and make simplifications according to this; e.g., if
        # we queried s0 == 0, we would immediately return False without
        # returning a guard.)
        #
        # So, it is perhaps easier to flip things on their head: the guard
        # expressions we generate here say what simplifications are valid,
        # and what are not. Below, we explain each of the guard expressions
        # we generate

        # TODO: Make this more efficient by binding all the size/stride/offsets
        # to locals before performing tests on them.

        from torch._dynamo.source import NegateSource, TensorPropertySource, TensorProperty

        # Actual codegen must be delayed as we don't necessarily know what
        # the symbol mapping is
        # (source, expected sympy expression) pairs, one per tracked input value.
        input_guards = []

        # Maps each symbol to every source that binds it (or binds its negation,
        # in which case the source is wrapped in NegateSource).
        symbol_to_source = collections.defaultdict(list)

        # How do we know what the value of s0 is? Fresh variables can only be
        # bound by inputs, so there MUST be some other input which binds the
        # variable. If there is no such input, this is an error in our
        # system. We record where all symbols come from, to help you diagnose
        # why those symbols didn't occur.
        #
        # In fact, generally speaking it is only possible for the "outermost"
        # user of a ShapeEnv to evaluate the guards, because some inputs may
        # not be available to inner levels. For example, Dynamo can guard on
        # tensors that never actually become graph arguments (they are
        # pruned). In this case, only Dynamo knows about these arguments.
        def track_symint(source, val):
            if isinstance(val, SymInt):
                s = val.node.expr

                if isinstance(s, sympy.Symbol):
                    symbol_to_source[s].append(source)
                elif isinstance(-s, sympy.Symbol):
                    symbol_to_source[-s].append(NegateSource(source))

                input_guards.append((source, s))
            else:
                # Plain ints must be matched exactly.
                input_guards.append((source, sympy.Integer(val)))

        for t, source in zip(placeholders, sources):
            if isinstance(source, str):
                from torch._dynamo.source import LocalSource
                source = LocalSource(source)
            assert isinstance(source, Source)
            if t is None:
                continue
            if isinstance(t, SymInt):
                track_symint(source, t)
                continue
            assert isinstance(t, torch.Tensor)
            for i, s in enumerate(t.size()):
                track_symint(TensorPropertySource(source, TensorProperty.SIZE, i), s)
            for i, s in enumerate(t.stride()):
                track_symint(TensorPropertySource(source, TensorProperty.STRIDE, i), s)
            track_symint(TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), t.storage_offset())

        exprs = []

        # 1. Every input must equal the final simplified symbolic expression
        #    stored on the placeholder.  Given a placeholder (s0*2, s1),
        #    if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
        #    This does a lot of work: it covers duck sizing and equality guards.
        if not _simplified:
            for source, expr in input_guards:
                # Small optimization: the symbol is printed as its first
                # source, so "source == source" would be vacuous.
                if (
                    isinstance(expr, Symbol) and
                    expr in symbol_to_source and
                    source == symbol_to_source[expr][0]
                ):
                    continue
                sexpr = ShapeGuardPrinter(symbol_to_source, source_ref).doprint(expr)
                exprs.append(f"{source_ref(source)} == {sexpr}")

        # 2. Every guard must evaluate to True (but remember many guards
        #    like s0 == s1*2 become trivial due to simplification); guards
        #    that can be decided statically are skipped.
        for g, tb in self.guards:
            if self._maybe_evaluate_static(g) is not None:
                continue
            g = self.simplify(g)
            try:
                exprs.append(ShapeGuardPrinter(symbol_to_source, source_ref).doprint(g))
            except Exception:
                log.warning(f"Failing guard allocated at: \n{tb}")
                raise

        # 3. Every symbol must not be equal to 0/1
        if not _simplified:
            for sources in symbol_to_source.values():
                assert sources
                # We must assert that each symbol is not zero or one, as we make
                # negative inferences on shape variables
                exprs.append(f"{source_ref(sources[0])} != 0 and {source_ref(sources[0])} != 1")

        return exprs
  1052. def evaluate_guards_for_args(self, placeholders, args):
  1053. from torch._dynamo.source import GlobalSource
  1054. arg_names = [f"t{i}" for i in range(len(args))]
  1055. guards = self.produce_guards(placeholders, [GlobalSource(a) for a in arg_names])
  1056. if guards:
  1057. code = " and ".join(guards)
  1058. return eval(code, {}, dict(zip(arg_names, args)))
  1059. return True
  1060. def bind_symbols(self, placeholders, args):
  1061. # Given a paired list of placeholders (fake tensors with
  1062. # symbolic sizes) and concrete arguments (regular tensors
  1063. # with real sizes), returns a dictionary mapping each
  1064. # symbol to its real value. So for example, if you
  1065. # have a placeholder with size (s0, s1), binding
  1066. # (2, 4) to it will give you {s0: 2, s1: 4}. This is
  1067. # not guaranteed to bind ALL symbols in the ShapeEnv;
  1068. # we can't bind a symbol if it doesn't occur in any placeholder,
  1069. # and symbols that already have replacements won't get bindings.
  1070. # This is a little duplicative with evaluate_guards but
  1071. # it's different enough that it seemed cleanest to make
  1072. # another copy. This assumes the guards are already checked,
  1073. # though if it's cheap we'll check for shenanigans
  1074. bindings: Dict[sympy.Symbol, int] = {}
  1075. def bind_symint(arg, val):
  1076. if isinstance(val, SymInt):
  1077. s = val.node.expr
  1078. if isinstance(s, sympy.Symbol):
  1079. if s in bindings:
  1080. assert bindings[s] == arg, f"{bindings[s]} != {arg}"
  1081. else:
  1082. bindings[s] = arg
  1083. elif isinstance(-s, sympy.Symbol):
  1084. if -s in bindings:
  1085. assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}"
  1086. else:
  1087. bindings[-s] = -arg
  1088. for t, arg in zip(placeholders, args):
  1089. if t is None:
  1090. continue
  1091. if isinstance(t, SymInt):
  1092. bind_symint(arg, t)
  1093. continue
  1094. assert isinstance(t, torch.Tensor)
  1095. for i, s in enumerate(t.size()):
  1096. bind_symint(arg.size(i), s)
  1097. for i, s in enumerate(t.stride()):
  1098. bind_symint(arg.stride(i), s)
  1099. bind_symint(arg.storage_offset(), t.storage_offset())
  1100. return bindings
  1101. def get_nontrivial_guards(self):
  1102. return [self.simplify(guard.expr) for guard in self.guards if self._maybe_evaluate_static(guard.expr) is None]
  1103. def format_guards(self, verbose=False):
  1104. def format_tb(tb):
  1105. if not verbose:
  1106. return ""
  1107. return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
  1108. return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)
  1109. def get_shape_groups(self):
  1110. shape_groups = collections.defaultdict(list)
  1111. for k, v in self.replacements.items():
  1112. shape_groups[v].append(k)
  1113. return shape_groups
  1114. @_lru_cache
  1115. def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]":
  1116. """
  1117. Tries to evaluate expr without introducing guards
  1118. """
  1119. expr = self.simplify(expr)
  1120. # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
  1121. symbols = list(expr.free_symbols)
  1122. new_shape_env = {
  1123. k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1
  1124. for idx, k in enumerate(symbols)
  1125. # Do not assume unbacked symints are > 1
  1126. if k in self.var_to_val
  1127. }
  1128. new_expr = expr.xreplace(new_shape_env)
  1129. floor_div_replace = {}
  1130. for atom in new_expr.atoms(FloorDiv):
  1131. floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
  1132. new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
  1133. if len(list(new_expr.free_symbols)) == 0:
  1134. return new_expr
  1135. return None
  1136. @_lru_cache
  1137. def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
  1138. replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols}
  1139. return safe_expand(expr.xreplace(replacements))
  1140. @_lru_cache
  1141. def _update_divisible(self):
  1142. new_divisible = set()
  1143. for k in self.divisible:
  1144. res = self.replace(k)
  1145. if len(res.free_symbols) > 0:
  1146. new_divisible.add(k)
  1147. self.divisible = new_divisible
  1148. @_lru_cache
  1149. def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
  1150. expr = self.replace(expr)
  1151. if expr.has(FloorDiv):
  1152. self._update_divisible()
  1153. div_replacements = {}
  1154. for atom in expr.atoms(FloorDiv):
  1155. base, divisor = atom.args
  1156. if self.replace(base % divisor) in self.divisible:
  1157. div_replacements[atom] = sympy.floor(base / divisor)
  1158. expr = expr.xreplace(div_replacements)
  1159. expr = safe_expand(expr)
  1160. return expr
  1161. @lru_cache(256)
  1162. def size_hint(self, expr: "sympy.Expr"):
  1163. """
  1164. Gets a size hint for a given expression from the underlying shapes we had.
  1165. Does not introduce a guard, so only use this when you can guarantee that
  1166. your code is still valid for arbitrary shapes (such as optimization decisions)
  1167. """
  1168. result_expr = safe_expand(expr).xreplace(self.var_to_val)
  1169. if len(result_expr.free_symbols) != 0:
  1170. raise self._make_data_dependent_error(result_expr)
  1171. return result_expr
  1172. def _make_data_dependent_error(self, expr):
  1173. # TODO: in a Dynamo context, having user code, and having the
  1174. # name of the local, will be much better
  1175. accesses = '\n\n'.join(
  1176. f"Data dependent variable '{s}' allocated at:\n{s.stack}"
  1177. for s in expr.free_symbols
  1178. )
  1179. return GuardOnDataDependentSymNode(
  1180. f"\n\n{accesses}\n"
  1181. "GuardOnDataDependentSymNode: It appears that you're trying to get "
  1182. "a value out of symbolic int/float "
  1183. "whose value is data-dependent (and thus we do not know the true value.) "
  1184. f"The expression we were trying to evaluate is {expr}. "
  1185. "Scroll up to see where each of these data-dependent accesses originally occurred."
  1186. # TODO: Help text about how to use our runtime tests to fix this
  1187. # problem
  1188. )
  1189. @_lru_cache
  1190. def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
  1191. """
  1192. Implements a DSU-like algorithm to find the variable that represents a
  1193. Also handles transitive non-identity replacements.
  1194. a: b + c
  1195. c: d
  1196. """
  1197. if a not in self.replacements:
  1198. return a
  1199. res = self.replacements[a]
  1200. cur_replace = {s: self._find(s) for s in res.free_symbols}
  1201. self.replacements[a] = self.replacements[a].xreplace(cur_replace)
  1202. return self.replacements[a]
    @lru_cache(256)
    def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bool) -> None:
        """
        Evaluates the result of an eq call. If true, uses information to
        simplify shapes (i.e. a == b or a % 5 == 0)

        Acts only when the comparison is known to hold as an equality: an Eq
        evaluated to True, or an Ne evaluated to False. Then:
          * without a Mod: solve for the free symbol with the largest size
            hint. If there is exactly one solution and every sub-expression of
            it is an integer, record it in ``replacements``.
          * with a Mod: if solving for the (first) Mod term yields exactly one
            solution equal to 0, record that Mod term in ``divisible``.
        Expressions with more than 5 free symbols are ignored.
        """
        assert type(concrete_bool) is bool
        if isinstance(expr, sympy.Eq):
            if not concrete_bool:
                return
        # NB: Apparently this is load bearing; to see what test fails if
        # you comment it out run:
        # python test/functorch/test_aotdispatch.py -k
        # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32
        elif isinstance(expr, sympy.Ne):
            if concrete_bool:
                return
        free = list(expr.free_symbols)

        assert len(free) > 0, "The expression should not be static by this point"
        # In case of really gnarly expression, we don't blow up
        if len(free) > 5:
            return
        # Largest hint first (ties broken by name) so free[0] is the variable we solve for.
        # NOTE(review): size_hint raises for symbols without an entry in
        # var_to_val (e.g. unbacked symints); confirm such symbols cannot reach here.
        free = sorted(free, key=lambda x: (self.size_hint(x), x.name), reverse=True)  # type: ignore[attr-defined]
        lhs = expr.lhs
        rhs = expr.rhs
        if not expr.has(sympy.Mod):
            try:
                solutions = sympy.solve(lhs - rhs, free[0], dict=True)
                if len(solutions) != 1:
                    return
                solution = solutions[0][free[0]]
                # Only accept solutions built entirely from integer terms.
                if all(t.is_integer for t in sympy.preorder_traversal(solution)):
                    new_var = self._find(solution)
                    self.replacements[cast(sympy.Symbol, free[0])] = new_var
            except NotImplementedError:
                pass
            except RecursionError:
                log.warning(f"RecursionError in sympy.solve({lhs} - {rhs}, {free[0]})")
        if expr.has(sympy.Mod):
            mod_expr = tuple(expr.atoms(sympy.Mod))[0]
            try:
                solutions = sympy.solve(lhs - rhs, mod_expr, dict=True)
                if len(solutions) == 1 and solutions[0][mod_expr] == 0:
                    self.divisible.add(mod_expr)
            except NotImplementedError:
                pass
        return
    @lru_cache(256)
    def evaluate_expr(self, expr: "sympy.Expr", hint=None):
        """
        Given an expression, evaluates it, adding guards if necessary

        Returns ``expr`` itself if it has no free symbols, the static result if
        _maybe_evaluate_static can decide it, and otherwise its concrete value:
        ``hint`` (sympified) if given, else size_hint(expr). In the last case a
        guard pinning ``expr`` to that value is recorded, unless guards are
        suppressed on the current thread. Eq/Ne expressions are also passed to
        _maybe_guard_eq so the ShapeEnv can learn replacements from them.

        When ``hint`` is None, size_hint raises the data-dependent error if
        ``expr`` involves symbols without a recorded value.

        NB: because of lru_cache, a repeated call with the same (expr, hint)
        returns the cached value and does not record the guard again.
        """
        if len(expr.free_symbols) == 0:
            return expr
        expr = self.simplify(expr)
        static_expr = self._maybe_evaluate_static(expr)
        if static_expr is not None:
            return static_expr

        if hint is None:
            concrete_val = self.size_hint(expr)
        else:
            concrete_val = sympy.sympify(hint)

        if isinstance(expr, (sympy.Eq, sympy.Ne)):
            self._maybe_guard_eq(expr, bool(concrete_val))
            # TODO: If we successfully eliminate a symbol via equality, it
            # is not actually necessary to save a guard for the equality,
            # as we will implicitly generate a guard when we match that
            # input against the symbol

        # TODO: optimize this; avoid formatting traces until we need them
        # NB: drop two frames; evaluate_expr and the Sym* function that
        # actually called us
        if not self._suppress_guards_tls():
            stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2]))
            # Booleans are stored as the expression (or its negation);
            # anything else as an equality with the concrete value.
            if concrete_val is sympy.true:
                self.guards.append(ShapeGuard(expr, stack))
            elif concrete_val is sympy.false:
                self.guards.append(ShapeGuard(sympy.Not(expr), stack))
            else:
                self.guards.append(
                    ShapeGuard(sympy.Eq(expr, concrete_val), stack))  # type: ignore[arg-type]
        return concrete_val