compilers.py

import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
from torch import SymInt
import torch.fx as fx
import torch.nn as nn
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols

from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)
import torch.utils._pytree as pytree


log = logging.getLogger(__name__)
# These canonicalizations are needed here (and not as decompositions), as the
# ops we're trying to canonicalize to are CompositeImplicitAutograd.
def _canonicalize(fx_g):
    for node in fx_g.graph.nodes:
        if node.target == torch.ops.aten._to_copy:
            node.target = torch.ops.aten.to
    fx_g.recompile()
    return fx_g
@contextmanager
def _disable_jit_autocast():
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` with the TorchScript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled.

    Returns:
        Torch scripted model.
    """
    with _disable_jit_autocast():
        strip_overloads(fx_g)

        for node in fx_g.graph.nodes:
            if (
                node.target == torch.ops.aten._to_copy
                and len(node.args) == 1
                and len(node.kwargs) == 1
                and "dtype" in node.kwargs
            ):
                node.target = torch.ops.aten.to

        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()
        fx_g.recompile()

        f = torch.jit.script(fx_g)
        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
            f(*inps)
    return f
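# Illustrative usage sketch (hypothetical helper, not part of the original API):
# ts_compile can be handed to aot_function as the forward/backward compiler,
# which is how memory_efficient_fusion below wires it up. The toy function and
# shapes here are made up for the example.
def _example_ts_compile_usage():
    def f(x, y):
        # A small differentiable toy function to trace and script.
        return (x * y).sin().sum()

    compiled_f = aot_function(f, fw_compiler=ts_compile, bw_compiler=ts_compile)
    x = torch.randn(4, 4, requires_grad=True)
    y = torch.randn(4, 4)
    # Forward runs the scripted graph; backward runs the scripted backward graph.
    compiled_f(x, y).backward()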
def _draw_graph_compile(fx_g, _, name, clear_meta=True):
    print(fx_g.code)
    draw_graph(fx_g, name, clear_meta=clear_meta)
    return fx_g


def draw_graph_compile(name):
    return make_boxed_compiler(
        partial(_draw_graph_compile, name=name)
    )
@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
    and can be used to check accuracy.

    .. warning::
        This API is experimental and likely to change.
    """
    return fx_g
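# Illustrative sketch (hypothetical helper, not in the original file): using nop
# as both compilers runs the traced graphs without any backend compilation,
# which is useful for checking that AOT tracing alone does not change numerics.
# The toy function is made up for the example.
def _example_nop_usage():
    def f(x):
        return torch.sin(x).cos().sum()

    checked_f = aot_function(f, fw_compiler=nop, bw_compiler=nop)
    x = torch.randn(8, requires_grad=True)
    # The traced-and-uncompiled result should match plain eager execution.
    assert torch.allclose(checked_f(x), f(x))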
class DebugInterpreter(fx.Interpreter):
    def run(self, *args):
        self.symbol_mapping = bind_symbols(self.module, *args)
        # Return the interpreted outputs so debug_nop hands them back to the caller.
        return super().run(*args)

    def run_node(self, n):
        import sympy

        def subst_symint(ni):
            if not isinstance(ni, SymInt):
                return ni
            r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
            assert len(r.free_symbols) == 0, r
            return int(r)

        def subst_symint_tuple(nis):
            return tuple(subst_symint(ni) for ni in nis)

        def check_significant_strides(a, b):
            if subst_symint(a.numel()) > 0:
                for idx in range(a.ndim):
                    if subst_symint(a.stride(idx)) != b.stride(idx) and subst_symint(a.size(idx)) > 1:
                        return False
            return True

        def check(nv, rv, desc):
            assert callable(desc)
            assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
            assert subst_symint_tuple(nv.size()) == rv.size(), \
                f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
            same_strides = check_significant_strides(nv, rv)
            assert same_strides, f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"

        r = super().run_node(n)
        if 'val' in n.meta:
            n_vals, n_spec = pytree.tree_flatten(n.meta['val'])
            r_vals, r_spec = pytree.tree_flatten(r)
            # TODO: There is some sort of problem where we record that an
            # operator returned a tuple/list, and then later it turns out the
            # real version of the operator returned a list/tuple. Need to
            # figure out what's actually going on here; the error itself is
            # harmless enough, as we only getitem out the outputs.
            # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
            assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
            for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
                if not isinstance(rv, torch.Tensor):
                    continue
                check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
        return r
@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns a (slow) interpreter over the FX graph module that also checks
    various debugging properties (e.g., that tracing strides matched real
    strides).
    """
    return DebugInterpreter(fx_g).run
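# Illustrative sketch (hypothetical helper): debug_nop can be swapped in wherever
# a compiler callback is expected; the DebugInterpreter then re-checks the
# recorded dtypes, sizes, and significant strides against the real run. The toy
# function below is made up for the example.
def _example_debug_nop_usage():
    def f(x):
        return (x + 1).transpose(0, 1).contiguous().sum()

    debugged_f = aot_function(f, fw_compiler=debug_nop, bw_compiler=debug_nop)
    # Runs through the interpreter; assertion errors point at mismatched metadata.
    debugged_f(torch.randn(3, 5, requires_grad=True)).backward()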
@make_boxed_compiler
def simple_ts_compile(fx_g, _):
    strip_overloads(fx_g)
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f, static_argnums=None):
    return aot_function(f, simple_ts_compile, static_argnums=static_argnums)
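# Illustrative sketch (hypothetical helper): nnc_jit scripts the traced forward
# and backward graphs with TorchScript via simple_ts_compile, so plain
# tensor-in/tensor-out functions can be wrapped directly. The toy function is
# made up for the example.
def _example_nnc_jit_usage():
    def f(x, y):
        return torch.sigmoid(x) * torch.tanh(y)

    jitted_f = nnc_jit(f)
    out = jitted_f(torch.randn(16), torch.randn(16))
    return out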
aten = torch.ops.aten
default_decompositions = {
    aten.detach,
    aten.gelu_backward,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
    aten.elu_backward,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.masked_fill.Scalar,
    aten.masked_fill.Tensor,
    aten.elu,
    aten.leaky_relu,
    aten.hardtanh,
    aten.hardswish,
    aten.hardsigmoid,
    aten.conj_physical,
    aten.is_same_size,
}

default_decompositions = get_decompositions(default_decompositions)
@make_boxed_compiler
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g
def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    static_argnums: Optional[Tuple[int]] = None,
    **kwargs,
):
    """
    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
    memory efficient fusion. It uses the
    :func:`min_cut_rematerialization_partition` partitioner to perform efficient
    recomputation. It uses NVFuser to compile the generated forward and backward
    graphs.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one or more arguments. Must return one or more Tensors.
        static_argnums (Optional[Tuple[int]]): An optional tuple of ints to mark
            the arguments of the function as static.
        **kwargs: Any other overrides you want to make to the settings.

    Returns:
        Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
        of the original :attr:`fn`, but whose forward and backward graphs have
        gone through recomputation optimizations, and the graphs have been
        compiled with nvfuser.
    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
        "static_argnums": static_argnums,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)
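# Illustrative sketch (hypothetical module and shapes; assumes a CUDA device
# since the generated graphs are compiled with NVFuser, per the docstring
# above): wrapping an nn.Module keeps its eager interface while the forward and
# backward graphs are partitioned and fused.
def _example_memory_efficient_fusion_usage():
    mod = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 32)).cuda()
    fused_mod = memory_efficient_fusion(mod)
    out = fused_mod(torch.randn(8, 32, device="cuda"))
    # Backward runs through the rematerialization-optimized backward graph.
    out.sum().backward()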
def debug_compile(fx_g, inps):
    fx_g.to_folder("foo")
    print(
        f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
    # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
"""
    )
    from foo import FxModule
    FxModule().cuda()(*inps)

    return ts_compile(fx_g, inps)
graph_index = 0


def get_inputs(input_data_path):
    """
    Return random inputs for the given inputs meta generated from _save_fx_default.
    """
    inputs = []
    with open(input_data_path, "rb") as f:
        inputs_meta = pickle.load(f)
        for meta in inputs_meta:
            if len(meta) == 1:
                # Scalar input: meta is (type,), e.g. (int,) or (float,).
                type = meta[0]
                input = type(random.random())
            else:
                type, shape, stride, dtype, device = meta
                if dtype in {
                    torch.int,
                    torch.int32,
                    torch.int64,
                    torch.bool,
                    torch.uint8,
                    int,
                    float,
                }:
                    input = torch.randint(0, 1, shape, dtype=dtype, device=device)
                else:
                    input = torch.rand(shape, dtype=dtype, device=device)
            inputs.append(input)
    return inputs
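# Illustrative sketch: get_inputs rebuilds random tensors matching the shapes,
# dtypes, and devices recorded in a .input file written by _save_fx_default /
# graph_dumper_aot below. The path here is hypothetical and depends on the
# current_name/folder_name used when dumping.
def _example_get_inputs_usage():
    input_path = "dumped_graphs/my_model/my_model_forward_0/my_model_forward_0.input"
    random_inputs = get_inputs(input_path)
    return random_inputs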
def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
    """
    The forward, backward, and joint computation graphs will be stored in
    {folder_name}/{current_name}/{current_name}_forward_{graph_index},
    {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
    {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
    The input shapes of the graphs will be stored in the .input files.
    These files can be loaded with pickle,
    and are lists of entries of the format (type, shape, stride, dtype, device).
    In the case of type = int or float, the entry is just (type,).
    For joint graph inputs, it is a nested list [[],[]]
    where the two inner lists have the same format.
    If dump_example_input is True, example_inputs will be stored in a .pt file.
    Since each function might produce multiple graphs,
    graph_index is used to distinguish different graphs.
    """
    from functorch.compile import aot_module_simplified

    def get_input_meta(args):
        input_meta = []
        if len(args) > 0 and isinstance(args[0], tuple):  # joint input
            input_meta += get_input_meta(args[0])
            input_meta += get_input_meta(args[1])
            return input_meta
        for arg in args:
            if type(arg) == int or type(arg) == float:
                input_meta.append((type(arg),))
            else:
                input_meta.append(
                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
                )
        return input_meta
    def graph_saver_helper(gm_to_save, args, type_name):
        global graph_index
        if len(gm_to_save.graph.nodes) == 0:
            log.log(
                logging.WARNING,
                f"No nodes in graph {current_name}_{type_name}_{graph_index}.",
            )
            return

        gm = copy.deepcopy(gm_to_save)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
        gm.recompile()

        input_meta = get_input_meta(args)

        os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
        gm.to_folder(
            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
        )
        pickle.dump(
            input_meta,
            open(
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",  # noqa: B950
                "wb",
            ),
        )  # noqa: E501
        if dump_example_input:
            torch.save(
                args,
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt",  # noqa: B950
            )  # noqa: E501
    def graph_saver_forward(gm, fw_args):
        graph_saver_helper(gm, fw_args, "forward")
        return gm

    def graph_saver_backward(gm, bw_args):
        graph_saver_helper(gm, bw_args, "backward")
        global graph_index
        graph_index += 1
        return gm

    def graph_saver_joint(gm, joint_args):
        graph_saver_helper(gm, joint_args, "joint")
        return default_partition(gm, joint_args)

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=graph_saver_forward,
        bw_compiler=graph_saver_backward,
        partition_fn=graph_saver_joint,
        decompositions=default_decompositions,
    )
# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
    """
    Dump the forward, backward, and joint computation graphs.

    Example Usage:

    save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input=False)

    optimize_ctx = torchdynamo.optimize(
        save_fx_func
    )

    with torch.enable_grad():
        with optimize_ctx:
            result = forward_and_backward_pass(model, example_inputs)
    """
    global graph_index
    graph_index = 0
    return partial(_save_fx_default, current_name, folder_name, dump_example_input)