# utils.py
import collections
import contextlib
import functools
import glob
import itertools
import logging
import math
import operator
import os
import shutil
import tempfile
import textwrap
import time
from io import StringIO
from typing import Any, Dict, List, Optional, Union
from unittest import mock

import sympy

import torch
from torch.fx.immutable_collections import immutable_dict, immutable_list

from . import config, config as inductor_config
from .cuda_properties import get_device_capability

log = logging.getLogger(__name__)

VarRanges = Dict[sympy.Expr, sympy.Expr]

try:
    from triton.testing import do_bench
except ImportError:

    def do_bench(*args, **kwargs):
        raise NotImplementedError("requires Triton")

@functools.lru_cache(None)
def has_triton():
    if not torch.cuda.is_available():
        return False

    try:
        import triton

        return triton is not None and get_device_capability() >= (7, 0)
    except ImportError:
        return False

@functools.lru_cache(None)
def has_torchvision_roi_align():
    try:
        from torchvision.ops import roi_align  # noqa: F401

        return roi_align is not None and hasattr(
            getattr(torch.ops, "torchvision", None), "roi_align"
        )
    except ImportError:
        return False

def conditional_product(*args):
    return functools.reduce(operator.mul, [x for x in args if x])


def sympy_product(it):
    return functools.reduce(operator.mul, it, sympy.Integer(1))


def sympy_dot(seq1, seq2):
    assert len(seq1) == len(seq2)
    return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))


def unique(it):
    return {id(x): x for x in it}.values()


def ceildiv(numer: int, denom: int):
    assert isinstance(numer, int) and isinstance(denom, int)
    return -(numer // -denom)
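
# Illustrative note (not in the original file): ceildiv uses Python's floor
# division on a negated denominator to round up without floating point, e.g.
#
#     ceildiv(10, 3) == -(10 // -3) == -(-4) == 4
#     ceildiv(9, 3) == -(9 // -3) == -(-3) == 3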

def convert_shape_to_inductor(lst: List[Union[int, torch.SymInt]]) -> List[sympy.Expr]:
    """
    Gets the shape and stride of a tensor. For non-symbolic tensors, this is
    trivial. But for symbolic tensors, we need to map from SymIntNode into
    sympy.Expr.
    """
    return [
        i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst
    ]


def convert_shape_to_symint(
    lst: List[Union[int, sympy.Expr]]
) -> List[Union[int, torch.SymInt]]:
    """
    Takes a list of shapes from Inductor and converts them into symints (or just
    ints if all shapes are static).
    """
    from .virtualized import V

    return [
        i
        if isinstance(i, int)
        else int(i)
        if isinstance(i, sympy.Integer)
        else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
        for i in lst
    ]

def gen_gm_and_inputs(target, args, kwargs):
    g = torch.fx.Graph()
    g_args = []
    a_args = []
    for n, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            g_args.append(g.placeholder(f"arg{n}"))
            a_args.append(arg)
        else:
            g_args.append(arg)
    assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
    node = g.call_function(target, tuple(g_args), kwargs)
    if (
        len(target._schema.returns) == 1
        and str(target._schema.returns[0].type) == "Tensor"
    ):
        node = (node,)
    g.output(node)

    gm = torch.fx.GraphModule({}, g)
    return gm, a_args
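
# Illustrative sketch (not in the original file): gen_gm_and_inputs builds a
# single-op FX graph around an ATen overload (anything carrying a _schema), e.g.
#
#     gm, tensor_args = gen_gm_and_inputs(
#         torch.ops.aten.add.Tensor, (torch.randn(2), torch.randn(2)), {}
#     )
#     gm(*tensor_args)  # runs the captured op on the collected tensor inputs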

def synchronize():
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def timed(model, example_inputs, times=1):
    synchronize()
    torch.manual_seed(1337)
    t0 = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
        synchronize()
    t1 = time.perf_counter()
    # GC the result after timing
    assert result is not None
    return t1 - t0


def print_performance(fn, args=(), times=10, repeat=10, baseline=1.0):
    timings = torch.tensor([timed(fn, args, times) for _ in range(repeat)])
    took = torch.median(timings)
    print(f"{took/baseline:.6f}")
    return took
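
# Illustrative usage (not in the original file): `timed` runs the callable
# `times` times under a fixed seed and returns the wall-clock total;
# `print_performance` repeats that measurement and prints the median, e.g.
#
#     model = torch.nn.Linear(8, 8)
#     print_performance(model, args=(torch.randn(4, 8),), times=5, repeat=3)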

# Give the (otherwise unhashable) FX immutable collections a hash so they can
# be used as cache keys (see freeze_inputs below).
immutable_dict.__hash__ = lambda self: hash(tuple(self.items()))
immutable_list.__hash__ = lambda self: hash(tuple(self))

def freeze_inputs(f):
    """
    Useful for wrapping lists/dicts in hashable immutable containers
    for caching purposes.
    """

    def freeze_value(x):
        if isinstance(x, (immutable_dict, immutable_list)):
            return x
        if isinstance(x, list):
            return immutable_list(x)
        if isinstance(x, dict):
            return immutable_dict(x)
        return x

    @functools.wraps(f)
    def wrapped(*args):
        args = [freeze_value(x) for x in args]
        return f(*args)

    wrapped.cache_info = f.cache_info
    return wrapped
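
# Illustrative sketch (hypothetical function, not in the original file):
# freeze_inputs makes list/dict arguments hashable so an inner lru_cache can
# key on them, e.g.
#
#     @freeze_inputs
#     @functools.lru_cache(None)
#     def total_numel(sizes):
#         return sympy_product(sizes)
#
#     total_numel([2, 3, 4])  # list is frozen to an immutable_list before lookup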

def precompute_method(obj: Any, method: str):
    """Replace obj.method() with a new method that returns a precomputed constant."""
    result = getattr(obj, method)()
    setattr(obj, method, lambda: result)


def precompute_methods(obj: Any, methods: List[str]):
    """Replace methods with new methods that return precomputed constants."""
    for method in methods:
        precompute_method(obj, method)
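
# Illustrative sketch (hypothetical object, not in the original file):
#
#     class _Layout:
#         def numel(self):
#             return 12  # imagine something expensive here
#
#     layout = _Layout()
#     precompute_methods(layout, ["numel"])
#     layout.numel()  # now returns the precomputed 12 without re-running the body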

def cmp(a, b):
    return int(a > b) - int(a < b)


def cache_on_self(fn):
    key = f"__{fn.__name__}_cache"

    @functools.wraps(fn)
    def wrapper(self):
        if not hasattr(self, key):
            setattr(self, key, fn(self))
        return getattr(self, key)

    return wrapper
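
# Illustrative sketch (hypothetical class, not in the original file): the
# result of the zero-argument method is stored on the instance under
# "__<name>_cache" and reused on every later call.
#
#     class _Node:
#         @cache_on_self
#         def sorted_reads(self):
#             return object()  # stands in for an expensive computation
#
#     n = _Node()
#     assert n.sorted_reads() is n.sorted_reads()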

def get_fused_kernel_name(node_schedule):
    return "_".join(
        ["fused"]
        + sorted(
            [
                str(origin.name)
                for origin in functools.reduce(
                    operator.or_,
                    [
                        node.node.origins
                        for node in node_schedule
                        if hasattr(node, "node")
                    ],
                )
                if origin.op == "call_function"
            ]
        )[0 : config.kernel_name_max_ops]
    )

def gather_origins(args, kwargs):
    import itertools

    from .ir import ComputedBuffer, IRNode

    def is_unrealized_node(n):
        return isinstance(n, IRNode) and not isinstance(n, ComputedBuffer)

    kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)]
    arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)]
    return set(itertools.chain(*arg_origins, *kwarg_origins))

def sympy_str(expr: sympy.Expr):
    """
    Normal sympy str is very slow; this is a lot faster. The results are
    somewhat worse, as it doesn't do as much simplification. So don't
    use this for final codegen.
    """
    if isinstance(expr, sympy.Symbol):
        return expr.name
    if isinstance(expr, sympy.Add):
        return " + ".join(map(sympy_str, expr.args))
    if isinstance(expr, sympy.Mul):
        return " * ".join(map(sympy_str, expr.args))

    from .ir import CleanDiv, FloorDiv, ModularIndexing

    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
        return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
    return str(expr)
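
# Illustrative example (not in the original file; uses sympy_symbol defined
# below): the fast printer joins arguments without re-simplifying, so the output
# resembles "x + 2 * y" (argument order may differ from sympy's own printer).
#
#     x, y = sympy_symbol("x"), sympy_symbol("y")
#     sympy_str(x + 2 * y)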

def sympy_symbol(name):
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert name[0] != "s"
    return sympy.Symbol(name, integer=True, positive=True)

def sympy_subs(expr: sympy.Expr, replacements: Dict[Any, Any]):
    """
    xreplace is faster than subs, but is way more picky
    """

    def promote_strings(key):
        if isinstance(key, str):
            return sympy_symbol(key)
        return key

    return expr.xreplace(
        {promote_strings(k): promote_strings(v) for k, v in replacements.items()}
    )
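
# Illustrative example (not in the original file): string keys and values are
# promoted to integer/positive symbols before the xreplace, e.g.
#
#     sympy_subs(sympy_symbol("i0") * 4, {"i0": "i1"})  # -> 4*i1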

def free_symbol_startswith(index: sympy.Expr, prefix: str):
    return any(v.name.startswith(prefix) for v in index.free_symbols)

def has_incompatible_cudagraph_ops(gm):
    forbidden_list = {
        "aten._fused_moving_avg_obs_fq_helper.default",
        "aten._fused_moving_avg_obs_fq_helper_functional.default",
        "fbgemm.dense_to_jagged.default",
        "fbgemm.jagged_to_padded_dense.default",
    }
    for node in gm.graph.nodes:
        if str(node.target) in forbidden_list:
            return True
    return False

instance_descriptor = collections.namedtuple(
    "instance_descriptor", ["divisible_by_16", "equal_to_1"]
)

@contextlib.contextmanager
def fresh_inductor_cache(cache_entries=None):
    """
    Contextmanager that provides a clean tmp cachedir for inductor.

    Optionally, pass a dict as 'cache_entries' to get a mapping of the
    filenames and sizes generated with this cache instance.
    """
    with tempfile.TemporaryDirectory() as inductor_cache_dir:
        with mock.patch.dict(
            os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
        ):
            triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
                yield
                if isinstance(cache_entries, dict):
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        files = os.listdir(triton_cache_dir)
                        cache_entries.update(
                            {
                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
                                for f in files
                                if ".lock" not in f
                            }
                        )
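
# Illustrative usage (not in the original file): compile under a throwaway
# cache directory and inspect what Triton wrote into it.
#
#     entries = {}
#     with fresh_inductor_cache(entries):
#         compiled = torch.compile(torch.nn.Linear(8, 8))
#         compiled(torch.randn(4, 8))
#     # entries now maps generated filenames to sizes (lock files excluded)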

def argsort(seq):
    # preserve original order for equal strides
    getter = seq.__getitem__
    a_r = range(len(seq))
    return list(reversed(sorted(a_r, key=getter, reverse=True)))  # noqa: C413
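
# Illustrative example (not in the original file): ascending argsort where ties
# are resolved via the reversed stable descending sort, e.g.
#
#     argsort([2, 1, 1, 3])  # -> [2, 1, 0, 3]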

@functools.lru_cache(8)
def get_dtype_size(dtype):
    return torch.empty((), dtype=dtype).element_size()

class IndentedBuffer:
    tabwidth = 4

    def __init__(self, initial_indent=0):
        self._lines = []
        self._indent = initial_indent

    def getvalue(self):
        buf = StringIO()
        for line in self._lines:
            if isinstance(line, DeferredLineBase):
                line = line()
                if line is None:
                    continue
            assert isinstance(line, str)
            buf.write(line)
            buf.write("\n")
        return buf.getvalue()

    def getrawvalue(self):
        buf = StringIO()
        for line in self._lines:
            if isinstance(line, DeferredLineBase):
                line = line()
                if line is None:
                    continue
            assert isinstance(line, str)
            # backslash implies line continuation
            if line.endswith("\\"):
                buf.write(line[:-1])
            else:
                buf.write(line)
                buf.write("\n")
        return buf.getvalue()

    def clear(self):
        self._lines.clear()

    def __bool__(self):
        return bool(self._lines)

    def prefix(self):
        return " " * (self._indent * self.tabwidth)

    def writeline(self, line):
        if isinstance(line, DeferredLineBase):
            self._lines.append(line.with_prefix(self.prefix()))
        elif line.strip():
            self._lines.append(f"{self.prefix()}{line}")
        else:
            self._lines.append("")

    def writelines(self, lines):
        for line in lines:
            self.writeline(line)

    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            self._indent += offset
            yield
            self._indent -= offset

        return ctx()

    def splice(self, other_code, strip=False):
        if isinstance(other_code, IndentedBuffer):
            dedent = float("inf")
            for line in other_code._lines:
                if line:
                    dedent = min(dedent, len(line) - len(line.lstrip()))
            if math.isinf(dedent):
                dedent = 0
            for line in other_code._lines:
                IndentedBuffer.writeline(self, line[dedent:])
        else:
            other_code = textwrap.dedent(other_code)
            if strip:
                other_code = other_code.lstrip()
            if not other_code:
                return
            other_code = other_code.rstrip()
            for line in other_code.split("\n"):
                self.writeline(line)
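
# Illustrative usage (not in the original file): IndentedBuffer accumulates
# lines and tracks the indentation level for generated source code.
#
#     buf = IndentedBuffer()
#     buf.writeline("def kernel():")
#     with buf.indent():
#         buf.writeline("return 0")
#     buf.getvalue()  # -> "def kernel():\n    return 0\n"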

class DeferredLineBase:
    """A line that can be 'unwritten' at a later time"""

    def __init__(self, line):
        if not line.strip():
            line = ""
        self.line = line

    def __call__(self) -> Optional[str]:
        """Returns either self.line or None to indicate the line has been 'unwritten'"""
        raise NotImplementedError()

    def _new_line(self, line: str) -> "DeferredLineBase":
        """Returns a new deferred line with the same condition"""
        raise NotImplementedError()

    def with_prefix(self, prefix):
        return self._new_line(f"{prefix}{self.line}")

    def lstrip(self):
        return self._new_line(self.line.lstrip())

    def __getitem__(self, index):
        return self._new_line(self.line[index])

    def __bool__(self):
        return bool(self.line)

    def __len__(self):
        return len(self.line)

@functools.lru_cache(None)
def is_big_gpu(index):
    cores = torch.cuda.get_device_properties(index).multi_processor_count
    if cores < 80:  # V100
        log.warning("not enough cuda cores to use max_autotune mode")
        return False
    return True


def use_triton_template(layout):
    return (
        inductor_config.max_autotune
        and layout.device.type == "cuda"
        and layout.dtype in (torch.float16, torch.bfloat16, torch.float32)
        and is_big_gpu(layout.device.index or 0)
    )

class DebugDirManager:
    counter = itertools.count(0)

    def __init__(self):
        self.id = next(DebugDirManager.counter)
        self.prev_debug_name = None

    def __enter__(self):
        self.prev_debug_name = torch._dynamo.config.debug_dir_root
        self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
        torch._dynamo.config.debug_dir_root = self.new_name

    def __exit__(self, *args):
        shutil.rmtree(self.new_name)
        torch._dynamo.config.debug_dir_root = self.prev_debug_name

def run_and_get_triton_code(fn, *args, **kwargs):
    from torch._inductor.debug import DebugContext
    from torch._inductor.virtualized import V

    torch._dynamo.reset()

    context = DebugContext()

    with DebugDirManager(), mock.patch.object(
        config.trace, "enabled", True
    ), context, V.set_debug_handler(context):
        dir_name = "/".join(context._path.split("/")[:-1]) + "/"
        fil = dir_name + "*inference*"
        existing_dirs = glob.glob(fil)

        fn(*args, **kwargs)

        assert context._path is not None
        dir_dbg = [x for x in glob.glob(fil) if x not in existing_dirs]

        assert len(dir_dbg) == 1, f"{dir_dbg}, {context._path}"

        full_name = os.path.join(dir_dbg[0], "output_code.py")
        with open(full_name, "r") as f:
            return f.read()

def developer_warning(msg):
    """
    Warnings that will be actionable for PyTorch developers, but not
    end users. Allows us to easily disable them in stable releases but
    keep them on for nightly builds.
    """
    if config.developer_warnings:
        log.warning(msg)
    else:
        log.info(msg)