nvfuser_executor.py

import operator
import os
from copy import deepcopy
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from warnings import warn

import torch
import torch.overrides
from torch._prims_common import (
    _torch_dtype_to_nvfuser_dtype_map,
    getnvFuserDtype,
    Number,
    number_type,
)
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

if torch.cuda.is_available():
    from nvfuser._C import (  # type: ignore[import]
        DataType,
        Fusion,
        FusionDefinition,
        Tensor,
    )
else:
    DataType = None


@lru_cache(None)
def get_nvprim_dump_nvtx():
    return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX")


DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
    {
        "use_python_fusion_cache": True,
        "allow_single_op_fusion": False,
    }
)
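
# Note (not in the original file): callers can override these defaults per
# call, e.g.
#   nvfuser_execute(gm, *args,
#                   executor_parameters={"use_python_fusion_cache": False})
# which skips the fusion cache and rebuilds the Fusion on each call.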


# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects
# for cached construction of nvFuser's Fusion
# TODO: change what is stored in the cache for nvFuser's Tensor objects
# https://github.com/pytorch/pytorch/issues/80551
@dataclass(frozen=True)
class nvFuserTensorTemplate:
    symbolic_shape: tuple
    contiguity: tuple
    dtype: DataType
    is_cpu: bool


@dataclass(frozen=True)
class nvFuserScalarTemplate:
    dtype: DataType


@lru_cache(maxsize=2048)
def compute_symbolic_shape(shape):
    """Computes the symbolic shape of a tensor.
    nvFuser specializes on size-1 dimensions as broadcasted dimensions.
    -1 is used to represent any size."""
    return tuple(1 if s == 1 else -1 for s in shape)
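
# Illustration (not in the original file): size-1 dimensions are kept as 1,
# every other extent becomes the wildcard -1, so
#   compute_symbolic_shape((4, 1, 3)) == (-1, 1, -1)
# and any (N, 1, M) tensor maps to the same cached fusion template.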


@lru_cache(maxsize=2048)
def compute_contiguity(shape, strides):
    """Computes the contiguity information to simplify internal indexing.
    Contiguous dimensions are represented by True, strided dimensions
    are represented by False.
    """
    from nvfuser._C import compute_contiguity

    return compute_contiguity(shape, strides)
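
# Illustration (assumed values, not in the original file): for a C-contiguous
# (2, 3) tensor the strides are (3, 1) and every dimension is contiguous,
#   compute_contiguity((2, 3), (3, 1)) -> (True, True)
# whereas a (2, 3) view with strides (6, 1) leaves a gap after the outer
# dimension, so that dimension would be marked False.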


def to_nvfuser_template_args(args):
    def to_nvfuser(arg):
        if isinstance(arg, torch.Tensor):
            return nvFuserTensorTemplate(
                compute_symbolic_shape(arg.size()),
                compute_contiguity(arg.size(), arg.stride()),
                getnvFuserDtype(arg.dtype),
                arg.is_cpu,  # type: ignore[attr-defined]
            )
        elif isinstance(arg, Number):
            return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
        else:
            return arg

    return tree_map(to_nvfuser, args)
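
# Illustration (not in the original file): a contiguous float32 CUDA tensor of
# shape (8, 1) becomes roughly
#   nvFuserTensorTemplate(symbolic_shape=(-1, 1), contiguity=(True, True),
#                         dtype=DataType.Float, is_cpu=False)
# so inputs that differ only in their non-broadcast extents hash to the same
# cached fusion.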


def _any_get_attr_used(call_function_nodes):
    # Returns True if any call_function node takes a get_attr node
    # (i.e. a constant tensor stored on the module) as an argument
    return any(
        filter(
            # bug in mypy https://github.com/python/mypy/issues/12682
            lambda n: any(  # type: ignore[arg-type]
                a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node)  # type: ignore[attr-defined]
            ),
            call_function_nodes,
        )
    )


# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024)  # type: ignore[arg-type]
def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Attempting to use nvFuser trace executor but CUDA is not available!"
        )

    # Everything in the graph must support nvFuser
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == operator.getitem:
            continue
        if (
            node.op == "call_function"
            and getattr(node.target, "impl_nvfuser", None) is None
        ):
            raise ValueError(
                "All call_function nodes in the graph must support nvfuser. "
                f"Node {node} with target {node.target} does not support nvfuser"
            )

    graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
    call_function_nodes = list(
        filter(lambda n: n.op == "call_function", gm.graph.nodes)
    )
    assert len(graph_input_nodes) == len(
        nv_args_templates
    ), "Number of placeholder nodes in the graph must match number of args"
    assert len(nv_args_templates) > 0, "There must be at least one argument"
    assert (
        len(call_function_nodes) > 0
    ), "Graph must contain at least one call_function node"
    assert not _any_get_attr_used(
        call_function_nodes
    ), "Constant tensors that are saved in the graph and used as arguments are not supported yet"

    # Checking output dtypes
    output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes))
    orig_flat_out, _ = tree_flatten(output_node.args[0])

    fusion = Fusion()
    with FusionDefinition(fusion) as fd:

        def _to_nvfuser_constant(arg):
            if isinstance(arg, Number):
                return fd.define_constant(arg)
            else:
                return arg

        class FusionInterpreter(torch.fx.Interpreter):
            def run_node(self, node):
                # Squeeze requires the original shape of args[0]
                if node.target in [
                    torch.ops.nvprims.squeeze,
                    torch.ops.nvprims.squeeze.default,
                ]:
                    original_shape = list(node.args[0].meta["tensor_meta"].shape)
                    assert len(node.args) == 2
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    args = [args[0], original_shape, args[1]]
                    return self.call_function(node.target, args, node.kwargs)

                if node.target in [
                    torch.ops.nvprims.native_batch_norm,
                    torch.ops.nvprims.native_batch_norm.default,
                ]:
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    assert len(args) == 8
                    training = args[5]
                    args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
                    args = args[:5] + (training,) + args6_end
                    return node.target.impl_nvfuser(fd, *args, **kwargs)

                return super().run_node(node)

            def call_function(self, target, args, kwargs):
                # This handles tuple unpacking
                if target == operator.getitem:
                    assert isinstance(args[0], tuple)
                    return target(*args, **kwargs)
                args = tuple(map(_to_nvfuser_constant, args))
                target = target.impl_nvfuser
                args = (fd,) + args
                return target(*args, **kwargs)

            def output(self, target, args, kwargs):
                flat_out, unflatten_spec = tree_flatten(args[0])
                for o, orig_o in zip(flat_out, orig_flat_out):
                    # Casting outputs to the original data type ensures that
                    # outputs produced by the fusion always agree with the
                    # original GraphModule
                    out_dtype = _torch_dtype_to_nvfuser_dtype_map.get(orig_o.meta["tensor_meta"].dtype)  # type: ignore[union-attr]
                    assert isinstance(
                        o, Tensor
                    ), "output from codegen has to be tensor type"
                    fd.add_output(fd.ops.cast(o, dtype=out_dtype))
                return args[0]

        def templates_to_nvfuser_inputs(arg):
            if isinstance(arg, nvFuserTensorTemplate):
                return fd.define_tensor(
                    arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu
                )
            elif isinstance(arg, nvFuserScalarTemplate):
                return fd.define_scalar(arg.dtype)
            else:
                return arg

        # Transforms the graph to call nvFuser lowerings
        nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates))
        out = FusionInterpreter(gm).run(*nv_args)
        flat_out, unflatten_spec = tree_flatten(out)

    return fusion, unflatten_spec
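
# Note (not in the original file): since make_nvfuser_fusion is lru_cache'd,
# repeated calls with the same GraphModule and the same hashable templates
# (equal symbolic shapes, contiguity, and dtypes) reuse the built Fusion;
# nvfuser_execute below bypasses the cache via make_nvfuser_fusion.__wrapped__
# when "use_python_fusion_cache" is False.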


def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
    executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
    flat_args, _ = tree_flatten(args)

    # Check for a CUDA-only fusion: at least one input tensor is on CUDA, and
    # every tensor input is either a CUDA tensor or a zero-dim CPU tensor
    if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all(  # type: ignore[attr-defined]
        (
            not isinstance(arg, torch.Tensor)
            or (arg.is_cpu and arg.ndim == 0)  # type: ignore[attr-defined]
            or arg.is_cuda  # type: ignore[attr-defined]
        )
        for arg in flat_args
    ):
        # Construction of the fusion is expensive and cached based on the
        # GraphModule and symbolic nvFuser args
        nv_template_args = to_nvfuser_template_args(flat_args)
        use_cache = executor_parameters.get(
            "use_python_fusion_cache",
            DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
        )
        if use_cache:
            fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args)  # type: ignore[misc]
        else:
            fusion, unflatten_spec = make_nvfuser_fusion.__wrapped__(gm, *nv_template_args)  # type: ignore[misc]

        # Inputs to fusion.execute correspond to the same template/symbolic
        # inputs marked with `define_tensor/scalar`
        concrete_fusion_inputs = tuple(
            arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
        )

        if get_nvprim_dump_nvtx():
            torch.cuda.nvtx.range_push(
                "fusion: {0}, graph: {1}".format(
                    fusion.id(),
                    str(
                        [
                            {
                                "op": n.op,
                                "name": n.name,
                                "args": n.args,
                                "kwargs": n.kwargs,
                            }
                            for n in gm.graph.nodes
                        ]
                    ),
                )
            )

        result = tree_unflatten(
            fusion.execute(concrete_fusion_inputs),  # type: ignore[has-type]
            unflatten_spec,  # type: ignore[has-type]
        )

        if get_nvprim_dump_nvtx():
            torch.cuda.nvtx.range_pop()

        return result
    else:
        warn(
            "nvfuser_executor is executed with non-cuda args, fallback to aten executor"
        )
        return gm.forward(*args)
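
# Note (not in the original file): nvfuser_execute requires every
# call_function node in gm to carry an impl_nvfuser lowering; graphs mixing
# supported and unsupported ops should go through nvfuser_execute_partitioned
# below, which splits the graph first.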


class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Special case to stop lowering to nvprims when converting to an unsupported type
        if (
            node.op == "call_function"
            and node.target == torch.ops.nvprims.convert_element_type.default
        ):
            return (
                _torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
                and _torch_dtype_to_nvfuser_dtype_map.get(
                    node.args[0].meta["tensor_meta"].dtype  # type: ignore[union-attr]
                )
                is not None
            )
        return node.op == "call_function" and (
            getattr(node.target, "impl_nvfuser", None) is not None
            or node.target == operator.getitem
        )


class PartitionedInterpreter(torch.fx.Interpreter):
    def call_module(self, target, args, kwargs):
        assert isinstance(target, str)
        assert len(kwargs) == 0
        submod = self.fetch_attr(target)
        # CapabilityBasedPartitioner hardcodes the names of the subgraphs
        # with supported_ops as "fused_" + subgraph id
        if target.startswith("fused_"):
            return nvfuser_execute(submod, *args)
        else:
            return super().call_module(target, args, kwargs)


class NvfuserGraphModule(torch.nn.Module):
    def __init__(self, gm, use_python_fusion_cache):
        super().__init__()
        self.gm = gm
        self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache}

    def __call__(self, *args):
        return nvfuser_execute(
            self.gm, *args, executor_parameters=self.executor_parameters
        )


# A set of operators that are supported by nvFuser
# but should not form a fusion group solely on their own
_non_compute_ops = [
    "torch.ops." + str(getattr(torch.ops.nvprims, prim).default)
    for prim in dir(torch.ops.nvprims)
    if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
    and getattr(torch.ops.nvprims, prim).return_type
    == torch._prims_common.RETURN_TYPE.VIEW
]

_allowed_single_node_partition_ops = [
    "torch.ops.nvprims.native_batch_norm.default",
    "torch.ops.nvprims.var_mean.default",
    "torch.ops.nvprims.var_mean.main",
]


def _remove_empty_like_fill(gm: GraphModule):
    # Remove empty_like + fill nodes that prevent lowering to nvprims.
    # This is a workaround for nonoptimal traces of C++ code `(1 - tensor)`
    # https://github.com/pytorch/pytorch/issues/86612
    def pattern(scalar, tensor):
        # Pattern for a C++ trace of `scalar - tensor`. We look for this
        # specific mix of aten and nvprims.sub nodes because we want to remove
        # the empty_like + fill nodes after the AOT Autograd trace has been
        # lowered to nvprims. In the future, nvFuser might support fill and
        # empty_like, and this workaround could be removed.
        empty_like = torch.ops.aten.empty_like.default(
            tensor, memory_format=torch.preserve_format
        )
        fill = torch.ops.aten.fill.Scalar(empty_like, scalar)
        sub = torch.ops.nvprims.sub.default(fill, tensor)
        return sub

    def replacement(scalar, tensor):
        return torch.ops.nvprims.sub.default(scalar, tensor)

    torch.fx.replace_pattern(gm, pattern, replacement)
    return gm
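
# Illustration (not in the original file): the rewrite above turns the traced
# sequence
#   %empty_like = aten.empty_like.default(%tensor, memory_format=...)
#   %fill = aten.fill.Scalar(%empty_like, scalar)
#   %sub = nvprims.sub.default(%fill, %tensor)
# into the single node
#   %sub = nvprims.sub.default(scalar, %tensor)
# which nvFuser can consume directly.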


# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024)  # type: ignore[arg-type]
def maybe_partition_graph(
    gm: GraphModule, allow_single_op_fusion: bool, use_python_fusion_cache: bool
):
    gm = _remove_empty_like_fill(gm)
    supported_ops = NvfuserPrimOperatorSupport()
    call_function_nodes = list(
        filter(lambda n: n.op == "call_function", gm.graph.nodes)
    )
    # The graph is partitioned only if at least one node is not supported by nvFuser
    any_unsupported = any(
        not supported_ops.is_node_supported(None, node) for node in call_function_nodes
    )
    any_unsupported |= len(call_function_nodes) == 0

    # When there are constant tensors in the graph, we can't partition it
    # because deepcopy fails. Here we just return the original graph to be
    # executed by eager mode
    # https://github.com/pytorch/pytorch/issues/84415
    if (
        _any_get_attr_used(call_function_nodes)
        or len(list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))) == 0
    ):
        return gm, True

    if any_unsupported:
        # CapabilityBasedPartitioner modifies the graph in-place,
        # so we need to make a copy of the graph
        gm = deepcopy(gm)
        partitioner = CapabilityBasedPartitioner(
            gm,
            supported_ops,
            allows_single_node_partition=allow_single_op_fusion,
            non_compute_ops=_non_compute_ops,
            allowed_single_node_partition_ops=_allowed_single_node_partition_ops,
        )
        partitions = partitioner.propose_partitions()
        partitioner.remove_bookend_non_compute_ops(partitions)
        if len(partitions) == 0:
            warn(
                "No partition found for the graph. "
                + "This is likely because the graph is not supported by nvFuser. "
                + "Please use the eager ATen mode to execute the graph.",
                category=RuntimeWarning,
            )
        partitioned_graph = partitioner.fuse_partitions(partitions)

        # Replacing the graph's fused submodules with a wrapper module whose
        # __call__() method calls nvfuser_execute.
        # This avoids the need to call the interpreter on the graph
        for node in partitioned_graph.graph.nodes:
            # TODO: use a better way to identify fused submodules
            if node.op == "call_module" and "fused_" in node.name:
                nvfuser_submodule = getattr(partitioned_graph, node.name)
                partitioned_graph.delete_submodule(node.target)
                gm.add_submodule(
                    node.target,
                    NvfuserGraphModule(nvfuser_submodule, use_python_fusion_cache),
                )

        # Go through the graph and replace all the nodes that were converted to
        # nvprims but won't be sent to nvFuser with a call to PyTorch's eager
        # mode. This is necessary because torch.ops.* have higher overhead than
        # calling the eager mode directly.
        for node in partitioned_graph.graph.nodes:
            if node.op == "call_function" and str(node.target).startswith("nvprims."):
                if getattr(node.target, "impl_aten", None) is not None:
                    node.target = node.target.impl_aten

        partitioned_graph.graph.eliminate_dead_code()
        partitioned_graph.recompile()
        return partitioned_graph, any_unsupported
    else:
        return gm, any_unsupported
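
# Illustration (not in the original file): after partitioning, a graph that
# mixes supported and unsupported ops looks roughly like
#   %fused_0 = call_module[target=fused_0](args=(%x,))          # via nvFuser
#   %other = call_function[target=<aten op>](args=(%fused_0,))  # eager ATen
# where the fused_0 submodule has been replaced by an NvfuserGraphModule.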


class NVTXInterpreter(torch.fx.Interpreter):
    def run_node(self, n):
        torch.cuda.nvtx.range_push(
            "name: {0}, args: {1}, op: {2}, kwargs: {3}".format(
                n.name, n.args, n.op, n.kwargs
            )
        )
        result = super().run_node(n)
        torch.cuda.nvtx.range_pop()
        return result


def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None):
    executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
    # maybe_partition_graph is cached, so we can't pass non-hashable arguments
    allow_single_op_fusion = executor_parameters.get(
        "allow_single_op_fusion",
        DEFAULT_NVFUSER_PYTHON_CONFIG["allow_single_op_fusion"],
    )
    use_python_fusion_cache = executor_parameters.get(
        "use_python_fusion_cache",
        DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
    )
    # When possible it's better to use nvfuser_execute directly
    # because it avoids GraphModule's overhead
    gm, is_partitioned = maybe_partition_graph(
        gm,
        allow_single_op_fusion=allow_single_op_fusion,
        use_python_fusion_cache=use_python_fusion_cache,
    )
    if is_partitioned:
        if get_nvprim_dump_nvtx():
            return NVTXInterpreter(gm).run(*args)
        else:
            return gm(*args)
    else:
        return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)
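
# Usage sketch (not in the original file; assumes the nvprims tracing context
# torch._prims.context.TorchRefsNvfuserCapabilityMode that shipped alongside
# this executor):
#
#   from torch._prims.context import TorchRefsNvfuserCapabilityMode
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def fn(a, b):
#       return (a + b).sin()
#
#   a = torch.randn(4, 4, device="cuda")
#   b = torch.randn(4, 4, device="cuda")
#   with TorchRefsNvfuserCapabilityMode():
#       gm = make_fx(fn)(a, b)
#   result = nvfuser_execute_partitioned(gm, a, b)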