import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence
from warnings import warn

import torch
import torch._decomp
import torch._prims
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides

from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.__rsub__: torch._refs.rsub,
        torch.Tensor.__rtruediv__: torch._refs.rtruediv,
        torch.Tensor.__floordiv__: torch._refs.floor_divide,
        torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
        torch.Tensor.__pow__: torch._refs.pow,
        torch.Tensor.__rpow__: torch._refs.rpow,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        torch.Tensor.to: torch._refs.to,
        torch.Tensor.sum_to_size: torch._refs.sum_to_size,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)

    # Support remapping torch.Tensor.foo to _refs.foo
    for s in dir(torch.Tensor):
        if s in torch._refs.__all__:
            r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)

    # Support conversions
    for s in torch._refs._conversions.__all__:
        tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s)
        r[tensor_attr] = torch._refs._conversions.__dict__.get(s)

    return r
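
# A minimal usage sketch of the cached map (illustrative only; nothing here
# runs at import time):
#
#   mapping = torch_to_refs_map()
#   assert mapping[torch.add] is torch._refs.add
#   assert mapping[torch.Tensor.add] is torch._refs.add  # Tensor methods too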


@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims()
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}


class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.ops.prims.* functions to
    use nvFuser's prims in torch.ops.nvprims.*

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with NvfuserPrimsMode():
    ...     torch.ops.prims.add(x, y)  # calls torch.ops.nvprims.add(x, y)

    By default, this context manager falls back on torch.ops.prims.* if the
    nvprim does not exist.

    It's possible to skip certain prims by passing their names to the skip_ops
    argument. skip_ops is expected to be a sequence of strings, e.g.,
    ("prims.add.default",). To check the expected name of a prim, one can
    use `torch.overrides.resolve_name`.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with NvfuserPrimsMode(skip_ops=("prims.add.default",)):
    ...     torch.ops.prims.add.default(x, y)  # does not call torch.ops.nvprims.add.default(x, y)
    """

    def __init__(self, *, skip_ops=()):
        self.skip_ops = skip_ops

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # If the function is in the skip list, then we don't want to
        # remap it to the nvprims.
        if torch.overrides.resolve_name(orig_func) in self.skip_ops:
            return orig_func(*args, **kwargs)
        if isinstance(orig_func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
            # str(orig_func) looks like "prims.add" or "prims.add.default":
            # the first component is the namespace, the second the op name
            namespace = str(orig_func).split(".")[0]
            name = str(orig_func).split(".")[1]
            if namespace == "prims":
                nvfunc = getattr(torch.ops.nvprims, name, None)
                if nvfunc is not None:
                    return nvfunc(*args, **kwargs)
        return orig_func(*args, **kwargs)


class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager falls back on the torch.* if the
    ref does not exist; set strict=True to error if this occurs.

    Even when a ref exists, it is sometimes desirable to fall back on the
    torch.* anyway; this behavior can be customized by passing a function
    to should_fallback_fn.
    """

    def __init__(
        self,
        strict=False,
        should_fallback_fn=lambda *_: False,
        prims_mode_cls=nullcontext,
    ):
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # For primitive operations, run them as is without interception,
        # unless we are in prims_mode, in which case we want to use nvprims
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)

        # For torch.ops.aten.*, use registered decompositions from torch._decomp.
        # torch._decomp.decomposition_table provides a mapping from
        # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
        # implementations.
        # There are other ways to implement this functionality,
        # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
        if func is None and isinstance(orig_func, torch._ops.OpOverload):
            func = torch._decomp.decomposition_table.get(orig_func, None)

        if func is not None:
            # If the ref exists, query whether we should use it or not
            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with self:
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        return orig_func(*args, **kwargs)
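
# Usage sketch (illustrative; x and y are assumed to be tensors):
#
#   # Error out instead of silently running the original torch.* call when
#   # no ref exists:
#   with TorchRefsMode(strict=True):
#       torch.add(x, y)  # runs torch._refs.add(x, y)
#
#   # Compose with NvfuserPrimsMode so that prims reached while executing
#   # refs are themselves remapped to nvprims:
#   with TorchRefsMode(prims_mode_cls=NvfuserPrimsMode):
#       torch.add(x, y)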


def _is_node_supported_nvfuser(node):
    return (
        node.op == "call_function"
        and getattr(node.target, "impl_nvfuser", None) is not None
    )


def _is_func_unsupported_nvfuser(
    torch_function_mode, orig_func, func, args, kwargs, *, skip_ops=()
):
    """
    Traces `func` under `torch_function_mode` and checks whether any of the
    traced nodes are not supported by nvFuser. If so, we should fall back to
    the original function.

    The `skip_ops` argument is expected to be a list of strings of function
    names that would match with `torch.overrides.resolve_name`.

    Args:
        torch_function_mode: The torch_function_mode context manager.
        orig_func: The original function; its name is used to check whether
            it should be skipped.
        func: The function to be traced.
        args: The args to be passed to the function.
        kwargs: The kwargs to be passed to the function.

    Keyword args:
        skip_ops: A list of ops to skip when checking if the function is
            supported.
    """
    # One case is easy to check: if the resolved name of the original
    # function is in the skip list, skip it (i.e., fall back)
    if torch.overrides.resolve_name(orig_func) in skip_ops:
        return True

    with torch_function_mode:
        try:
            gm = get_isolated_graphmodule(func, args, kwargs)
        except Exception as e:
            warn(
                "get_isolated_graphmodule failed on decomposition: "
                + func.__name__
                + " with error message: "
                + str(e)
            )
            # Report unsupported when tracing fails
            return True

    supported_ops = NvfuserPrimOperatorSupport()
    call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
    any_unsupported = any(
        not supported_ops.is_node_supported(None, node) for node in call_function_nodes
    )
    return any_unsupported
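
# This predicate is designed to be plugged into TorchRefsMode as
# should_fallback_fn with its keyword argument pre-bound, e.g.:
#
#   fallback = functools.partial(_is_func_unsupported_nvfuser, skip_ops=())
#   mode = TorchRefsMode(should_fallback_fn=fallback,
#                        prims_mode_cls=NvfuserPrimsMode)
#
# TorchRefsNvfuserCapabilityMode below wires this up with its own skip list.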


class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
    """
    A TorchRefsMode that uses a ref only when the traced decomposition is
    fully supported by nvFuser; otherwise it falls back to the original
    torch.* call.
    """

    def __init__(self, *, skip_ops=()):
        aten_ops_to_skip = (
            "aten._log_softmax.default",
            "aten._log_softmax_backward_data.default",
            "aten.expand.default",
        )
        self.skip_ops = tuple(skip_ops) + aten_ops_to_skip
        super().__init__(
            strict=False,
            should_fallback_fn=functools.partial(
                _is_func_unsupported_nvfuser,
                skip_ops=tuple(skip_ops) + aten_ops_to_skip,
            ),
            prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
        )

    # TODO: remove this once the version from _decomp/decompositions.py is
    # working with this context manager
    # This is a workaround for AOT Autograd graphs
    def _cudnn_batch_norm(
        self,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    ):
        a, b, c = torch.ops.nvprims.native_batch_norm(
            input,
            weight,
            bias,
            running_mean,
            running_var,
            training,
            exponential_average_factor,
            epsilon,
        )
        if training:
            return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
        return (
            a,
            weight.new_zeros((0,)),
            weight.new_zeros((0,)),
            input.new_zeros((0,), dtype=torch.uint8),
        )

    # This is a workaround for AOT Autograd graphs
    def _cudnn_batch_norm_backward(
        self,
        input,
        grad_output,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        epsilon,
        reserveSpace,
    ):
        func = torch._decomp.decomposition_table[
            torch.ops.aten.native_batch_norm_backward.default
        ]
        return func(
            grad_output,
            input,
            weight,
            running_mean,
            running_var,
            save_mean,
            save_var,
            True,
            epsilon,
            [True, True, True],
        )

    def _is_var_mean(self, func):
        return "torch.var_mean" == torch.overrides.resolve_name(func) or (
            isinstance(func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket))
            and "aten.var_mean" in str(func)
        )

    def _is_view_or_reshape(self, func):
        allowed_ops = {
            "torch.Tensor.view",
            "torch.Tensor.reshape",
            "torch.view_copy",
            "torch.reshape",
            "aten.view.default",
            "aten._unsafe_view.default",
            "aten.view_copy.default",
        } - set(self.skip_ops)
        return torch.overrides.resolve_name(func) in allowed_ops

    def _is_native_batch_norm(self, func):
        return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.native_batch_norm.default
            or func == torch.ops.aten.native_batch_norm
        )

    def _is_rand_like(self, func):
        return "torch.rand_like" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.rand_like
            or func == torch.ops.aten.rand_like.default
        )

    def _is_full(self, func):
        return "torch.full" == torch.overrides.resolve_name(func) or func in [
            torch.ops.aten.full,
            torch.ops.aten.full.names,
        ]

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # First we intercept calls for nvfuser-specific prims, bypassing
        # generic torch._refs
        if self._is_var_mean(orig_func):
            return torch.ops.nvprims.var_mean(*args, **kwargs)
        if (
            orig_func == torch.ops.aten.cudnn_batch_norm.default
            or orig_func == torch.ops.aten.cudnn_batch_norm
        ):
            with self:
                return self._cudnn_batch_norm(*args, **kwargs)
        # A workaround for AOT Autograd graphs
        # See https://github.com/pytorch/pytorch/pull/86115#issue-1394883782
        if (
            orig_func == torch.ops.aten.cudnn_batch_norm_backward.default
            or orig_func == torch.ops.aten.cudnn_batch_norm_backward
        ):
            with self:
                return self._cudnn_batch_norm_backward(*args, **kwargs)
        if self._is_view_or_reshape(orig_func):
            a, *shape = args
            shape = torch._prims_common.extract_shape_from_varargs(
                shape, validate=False
            )  # type: ignore[assignment]
            if len(kwargs) > 0:
                warn("view has ignored kwargs!")
            return torch.ops.nvprims.view(a, shape)
        if orig_func == torch.ops.aten._reshape_alias.default:
            a, shape, stride = args
            if len(kwargs) > 0:
                warn("view has ignored kwargs!")
            return torch.ops.nvprims.view(a, shape)
        if self._is_native_batch_norm(orig_func):
            return torch.ops.nvprims.native_batch_norm(*args, **kwargs)
        if self._is_rand_like(orig_func):
            if len(kwargs) > 0:
                warn("rand_like has ignored kwargs!")
            return torch.ops.nvprims.rand_like(*args)
        if self._is_full(orig_func):
            return torch.ops.nvprims.full(*args, **kwargs)
        # Then we use TorchRefsMode to interpret the rest
        return super().__torch_function__(orig_func, types, args, kwargs)
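
# End-to-end sketch (illustrative; the function and tensors are assumptions):
#
#   def fn(x, y):
#       return torch.nn.functional.relu(x + y)
#
#   with TorchRefsNvfuserCapabilityMode():
#       gm = get_isolated_graphmodule(fn, (x, y), {})
#   # The traced graph now calls torch.ops.nvprims.* where nvFuser support
#   # exists; ops that fail the tracing-based capability check fall back to
#   # their original implementations.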