compile_fx.py

import dataclasses
import functools
import itertools
import logging
import sys
import warnings
from typing import Any, Dict, List, Optional

import functorch
from functorch.compile import min_cut_rematerialization_partition

import torch._dynamo.config as dynamo_config
import torch.fx
from torch._dynamo import logging as dynamo_logging, utils as dynamo_utils
from torch._dynamo.utils import fake_mode_from_tensors
from torch._functorch.aot_autograd import make_boxed_func
from torch._subclasses.fake_tensor import FakeTensor

from .._dynamo.backends.common import aot_autograd
from . import config, metrics, overrides, pattern_matcher
from .debug import DebugContext
from .decomposition import select_decomp_table
from .graph import GraphLowering
from .mkldnn import convert_outplace_to_inplace
from .utils import developer_warning, get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V

log = logging.getLogger(__name__)
ALIGNMENT = 16


@dataclasses.dataclass
class BoxedBool:
    value: bool

    def __bool__(self):
        return self.value

    @staticmethod
    def disable(obj):
        if isinstance(obj, BoxedBool):
            obj.value = False
            return obj
        return False
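

# Example (illustrative): a BoxedBool is shared by reference, so a callee can
# flip the flag for every later user of the same box:
#   flag = BoxedBool(True)
#   BoxedBool.disable(flag)
#   bool(flag)  # -> False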


# copy_ fails when trying to write to tensors with memory overlap,
# for expanded dimensions (a dimension which used to have size 1 -> ?)
# we can select one element from that dimension and write to it
# to achieve writing to all values of that dimension of the input tensor
def get_expanded_dims(t):
    return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]


def index_expanded_dims(t, expanded_dims):
    for expanded_dim in expanded_dims:
        t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
    return t


def complex_memory_overlap(t):
    # if torch._debug_has_internal_overlap thinks this tensor potentially has
    # memory overlap internally, let's dig deeper to find out whether it's true.
    if torch._debug_has_internal_overlap(t) != 0:
        strides = t.stride()
        sizes = t.shape
        indices = list(range(len(strides)))
        indices = [x for _, x in sorted(zip(strides, indices))]
        for i in range(len(strides)):
            prev_stride = 1 if i == 0 else strides[indices[i - 1]]
            prev_size = 1 if i == 0 else sizes[indices[i - 1]]
            if strides[indices[i]] < prev_stride * prev_size:
                return True
    return False
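

# Example (illustrative): an expanded tensor has stride 0 along the broadcast
# dimension, so writing through a single slice of that dimension reaches every
# logical "copy" of the data:
#   x = torch.randn(3, 1).expand(3, 4)   # dim 1: size 4, stride 0
#   get_expanded_dims(x)                  # -> [1]
#   index_expanded_dims(x, [1]).shape     # -> torch.Size([3, 1])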


@functools.lru_cache(None)
def _step_logger():
    return dynamo_logging.get_step_logger(log)


@functools.lru_cache(None)
def _warn_tf32_disabled():
    if (
        torch.cuda.is_available()
        and not torch.backends.cuda.matmul.allow_tf32
        and torch.cuda.get_device_capability() >= (8, 0)
    ):
        warnings.warn(
            "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
            "Consider setting `torch.set_float32_matmul_precision('high')` for better performance."
        )


def is_tf32_warning_applicable(gm: torch.fx.GraphModule):
    aten = torch.ops.aten
    tf32_ops = {
        aten.mm.default,
        aten.addmm.default,
        aten.bmm.default,
        aten.baddbmm.default,
    }
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target in tf32_ops
            and isinstance(node.meta.get("val", None), torch.Tensor)
            and node.meta["val"].dtype == torch.float32
            and node.meta["val"].device.type == "cuda"
        ):
            return True
    return False


@DebugContext.wrap
def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
    shape_env = _shape_env_from_inputs(example_inputs)
    graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
    with V.set_graph_handler(graph):
        graph.run(*example_inputs)
        num_bytes, nodes_num_elem = graph.count_bytes()
        metrics.num_bytes_accessed += num_bytes
        metrics.nodes_num_elem += nodes_num_elem
    return make_boxed_func(gm.forward)


@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
def compile_fx_inner(
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    cudagraphs=None,
    num_fixed=0,
    is_backward=False,
    graph_id=None,
):
    if is_tf32_warning_applicable(gm):
        _warn_tf32_disabled()

    if dynamo_utils.count_calls(gm.graph) == 0:
        return make_boxed_func(gm.forward)

    # lift the maximum depth of the Python interpreter stack
    # to adapt large/deep models
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000))

    _step_logger()(
        logging.INFO,
        "torchinductor compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )
    V.debug.fx_graph(gm, example_inputs)

    if cudagraphs is None:
        cudagraphs = config.triton.cudagraphs

    shape_env = _shape_env_from_inputs(example_inputs)
    fake_mode = fake_mode_from_tensors(
        example_inputs
    ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)

    with V.set_fake_mode(fake_mode):
        pattern_matcher.fx_passes(gm)
        V.debug.fx_graph_transformed(gm, example_inputs)

        graph = GraphLowering(
            gm,
            shape_env=shape_env,
            num_static_inputs=num_fixed,
            graph_id=graph_id,
        )
        with V.set_graph_handler(graph):
            graph.run(*example_inputs)
            compiled_fn = graph.compile_to_fn()

    if cudagraphs:
        complex_memory_overlap_inputs = any(
            complex_memory_overlap(t) for t in example_inputs
        )

        if (
            set(graph.device_types) == {"cuda"}
            and not graph.mutated_inputs
            and not has_incompatible_cudagraph_ops(gm)
            and not complex_memory_overlap_inputs
        ):
            compiled_fn = cudagraphify(
                compiled_fn, example_inputs, static_input_idxs=range(num_fixed)
            )
        else:
            BoxedBool.disable(cudagraphs)

            if len(set(graph.device_types)) > 1:
                developer_warning("skipping cudagraphs due to multiple devices")
            elif set(graph.device_types) == {"cuda"}:
                if graph.mutated_inputs:
                    developer_warning("skipping cudagraphs due to input mutation")
                elif complex_memory_overlap_inputs:
                    developer_warning(
                        "skipping cudagraphs due to complex input striding"
                    )

    result = align_inputs(compiled_fn, example_inputs, range(num_fixed))
    _step_logger()(
        logging.INFO,
        "torchinductor done compiling "
        f"{'BACKWARDS' if is_backward else 'FORWARDS'} "
        f"graph {graph_id}",
    )

    # aot autograd needs to know to pass in inputs as a list
    result._boxed_call = True
    return result


def clone_preserve_strides(x):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
    )
    buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(buffer, x.size(), x.stride())
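

# Worked example (illustrative): needed_size is the smallest flat buffer that
# covers every addressable element of the strided view. For a (2, 2) view with
# strides (3, 1):
#   needed_size = (2 - 1) * 3 + (2 - 1) * 1 + 1 = 5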


def align_inputs(model, inputs, static_input_idxs=()):
    def is_aligned(storage_offset, dtype):
        return (storage_offset * get_dtype_size(dtype)) % ALIGNMENT == 0

    check_inputs = [
        i
        for i in range(len(inputs))
        if (
            i not in static_input_idxs
            or not is_aligned(inputs[i].storage_offset(), inputs[i].dtype)
        )
        and inputs[i].device.type == "cuda"
    ]

    if len(check_inputs) == 0:
        return model

    def run(new_inputs):
        for i in check_inputs:
            if new_inputs[i].data_ptr() % ALIGNMENT:
                new_inputs[i] = clone_preserve_strides(new_inputs[i])
        return model(new_inputs)

    return run


@dynamo_utils.dynamo_timed
def cudagraphify(model, inputs, static_input_idxs=()):
    # if using fake tensors, defer cudagraphs until we get real inputs at runtime
    if not any(isinstance(inp, FakeTensor) for inp in inputs):
        return cudagraphify_impl(model, inputs, static_input_idxs)

    compiled_fn = None

    def run(new_inputs):
        nonlocal compiled_fn
        if compiled_fn is None:
            with dynamo_utils.preserve_rng_state():
                compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
        return compiled_fn(new_inputs)

    return run


def remove_unaligned_input_idxs(inputs, static_input_idxs):
    """
    We require static inputs to be ALIGNMENT-aligned; drop any static input
    index whose data pointer is not, so it is treated as a regular input and
    copied into an aligned static buffer at call time.
    """
    aligned_static_input_idxs = {
        idx for idx in static_input_idxs if (inputs[idx].data_ptr() % ALIGNMENT) == 0
    }
    if len(aligned_static_input_idxs) != len(static_input_idxs):
        return aligned_static_input_idxs
    return static_input_idxs


def cudagraphify_impl(model, inputs, static_input_idxs=()):
    """
    Assumes inputs[static_input_idxs[i]] are always the same memory address
    """
    static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs)

    def static_input(x):
        """
        Copy an input while preserving strides
        """
        # TODO(jansel): figure out why this version doesn't work:
        # return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
        needed_size = (
            sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
        )
        buffer = torch.zeros(needed_size, dtype=x.dtype, device=x.device)
        return torch.as_strided(buffer, x.size(), x.stride())

    assert isinstance(inputs, (list, tuple))
    static_inputs = [
        static_input(x) if idx not in static_input_idxs else x.detach()
        for idx, x in enumerate(inputs)
    ]
    inps_expanded_dims = [
        get_expanded_dims(x) if idx not in static_input_idxs else []
        for idx, x in enumerate(inputs)
    ]

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    # copy static_inputs because it will be cleared in model
    with torch.cuda.stream(stream):
        model(list(static_inputs))
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(list(static_inputs))
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    if config.size_asserts:

        def run(new_inputs):
            assert len(static_inputs) == len(new_inputs)
            for idx, (dst, src, expanded_dims) in enumerate(
                zip(static_inputs, new_inputs, inps_expanded_dims)
            ):
                if idx in static_input_idxs:
                    assert dst.data_ptr() == src.data_ptr()
                else:
                    # TODO - could make one single op of multiple slices
                    # and avoid dispatch.
                    # Could also pre-index the `dst` tensors
                    dst = index_expanded_dims(dst, expanded_dims)
                    src = index_expanded_dims(src, expanded_dims)
                    dst.copy_(src)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    else:
        copy_indices = [
            idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
        ]

        def run(new_inputs):
            for idx in copy_indices:
                # copy fresh data into the captured static buffers before replay
                src = index_expanded_dims(new_inputs[idx], inps_expanded_dims[idx])
                dst = index_expanded_dims(static_inputs[idx], inps_expanded_dims[idx])
                dst.copy_(src)
            new_inputs.clear()
            graph.replay()
            return static_outputs

    return run
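

# Usage sketch (illustrative values): the returned runner copies fresh data
# into the captured static buffers, replays the recorded CUDA graph, and hands
# back the same static output tensors on every call:
#   runner = cudagraphify_impl(compiled_fn, example_inputs, static_input_idxs=(0,))
#   outputs = runner(list(new_inputs))   # new_inputs is cleared by the call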


def count_tangents(fx_g: torch.fx.GraphModule):
    """
    Infers which inputs are static for a backwards graph
    """

    def is_not_gradout(x):
        return "tangents" not in x.name

    arg_count = 0
    static_arg_idxs = []
    for n in fx_g.graph.nodes:
        if n.op == "placeholder":
            if is_not_gradout(n):
                static_arg_idxs.append(arg_count)
            arg_count += 1

    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)
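

# Example (illustrative): for a backward graph whose placeholders are named
# ["primals_1", "primals_2", "tangents_1"], the non-tangent inputs come first,
# so static_arg_idxs == [0, 1] and count_tangents returns 2.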


_graph_counter = itertools.count(0)


def compile_fx(
    model_: torch.fx.GraphModule,
    example_inputs_: List[torch.Tensor],
    inner_compile=compile_fx_inner,
    config_patches: Optional[Dict[str, Any]] = None,
):
    """Main entry point for compiling a given FX graph"""
    if config_patches:
        with config.patch(config_patches):
            return compile_fx(
                model_,
                example_inputs_,
                # need extra layer of patching as backwards is compiled out of scope
                inner_compile=config.patch(config_patches)(inner_compile),
            )

    assert not config._raise_error_for_testing
    functorch.compile.config.use_functionalize = True
    functorch.compile.config.use_fake_tensor = True

    with overrides.patch_functions():
        model_ = overrides.replace_fx(model_)
        model_ = overrides.fuse_fx(model_, example_inputs_)
    num_example_inputs = len(example_inputs_)
    cudagraphs = BoxedBool(
        config.triton.cudagraphs and not dynamo_config.dynamic_shapes
    )
    graph_id = next(_graph_counter)

    @dynamo_utils.dynamo_timed
    def fw_compiler(model: torch.fx.GraphModule, example_inputs):
        fixed = len(example_inputs) - num_example_inputs
        # Why convert out-of-place ops to inplace? Inductor supports inplace operations
        # well, and for custom inplace ops lowered as ExternKernel it is beneficial for
        # performance to use the inplace implementation when one is available.
        model = convert_outplace_to_inplace(model)
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            graph_id=graph_id,
        )

    @dynamo_utils.dynamo_timed
    def bw_compiler(model: torch.fx.GraphModule, example_inputs):
        fixed = count_tangents(model)
        return inner_compile(
            model,
            example_inputs,
            num_fixed=fixed,
            cudagraphs=cudagraphs,
            is_backward=True,
            graph_id=graph_id,
        )

    with overrides.patch_functions():
        # TODO: can add logging before/after the call to create_aot_dispatcher_function
        # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func
        # once torchdynamo is merged into pytorch
        return aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            decompositions=select_decomp_table(),
            partition_fn=functools.partial(
                min_cut_rematerialization_partition, compiler="inductor"
            ),
            keep_inference_input_mutations=True,
        )(model_, example_inputs_)
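

# Usage sketch (illustrative): compile_fx is the TorchInductor backend entry
# point and is normally reached through TorchDynamo rather than called directly,
# e.g.
#   opt_fn = torch._dynamo.optimize(compile_fx)(fn)
# or, at the user level,
#   opt_model = torch.compile(model, backend="inductor")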


def _shape_env_from_inputs(inputs):
    shape_env = None
    fake_mode = fake_mode_from_tensors(inputs)

    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
    # pass in real inputs for now.
    # if len(inputs) > 0:
    #     assert fake_mode is not None, breakpoint()

    if fake_mode is not None:
        return fake_mode.shape_env

    # TODO(voz): Should we always have one anyway?
    return None