# Module for defining "primitive" operations executable by nvFuser. This list
# exists to decouple the main set of primitives from the ones that provide a
# lowering of the op to nvFuser's Python interface. For the most part,
# torch.ops.nvprims is a subset of the primitives in torch.ops.prims, but
# additional primitives can be added in the future for the corresponding
# higher-level torch/aten functions.
from typing import Any, Dict, Optional, Tuple

import torch

import torch._prims_common as utils
from torch._prims_common import (
    DimsSequenceType,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    make_contiguous_strides_for,
    NumberType,
    ShapeType,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    backwards_not_supported,
    elementwise_type_promotion_wrapper,
)

nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")

nvprim_names = [
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "clone",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "real",
    "reciprocal",
    "neg",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "transpose",
    "trunc",
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
    "squeeze",
    "view_of",
    "broadcast_in_dim",
    "where",
    "convert_element_type",
    "sum",
    "var",
    "amax",
    "amin",
]

_nvfuser_impls: Dict[str, Any] = {}

_nvfuser_unary_ops = {
    "abs",
    "acos",
    "asin",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bitwise_not",
    "ceil",
    "erf",
    "erfc",
    "exp",
    "expm1",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "reciprocal",
    "neg",
    "real",
    "round",
    "rsqrt",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
}


def _assert_nvfuser_op_exists(fname: str):
    try:
        from nvfuser._C import FusionDefinition as fd  # type: ignore[import]

        assert getattr(fd.Operators, fname)
    except ImportError:
        # Not all PyTorch builds have nvfuser
        pass


for fname in _nvfuser_unary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a):
    return fd.ops.{fname}(a)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )
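
# For illustration, each iteration of the loop above generates (for example,
# when fname == "abs") the equivalent of:
#
#   _assert_nvfuser_op_exists("abs")
#
#   def _abs_nvfuser(fd, a):
#       return fd.ops.abs(a)
#
#   _nvfuser_impls["abs"] = _abs_nvfuser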

_nvfuser_binary_ops = {
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "div",
    "eq",
    "fmod",
    "ge",
    "gt",
    "le",
    "lt",
    "mul",
    "ne",
    "pow",
    "remainder",
    "sub",
}

for fname in _nvfuser_binary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b):
    return fd.ops.{fname}(a, b)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )

_nvfuser_ternary_ops = {
    "where",
}

for fname in _nvfuser_ternary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b, c):
    return fd.ops.{fname}(a, b, c)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )


def _native_batch_norm_nvfuser(
    fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
    """
    if weight is None:
        weight = fd.define_null_tensor()
    if bias is None:
        bias = fd.define_null_tensor()
    if running_mean is None:
        running_mean = fd.define_null_tensor()
    if running_var is None:
        running_var = fd.define_null_tensor()
    """
    return fd.ops.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        momentum,
        eps,
        training,
    )


def _broadcast_in_dim_nvfuser(
    fd: Any,
    a: TensorLikeType,
    shape: ShapeType,
    broadcast_dimensions: ShapeType,
):
    return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions)  # type: ignore[attr-defined]


def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.cast(a, nvfuser_dtype)  # type: ignore[attr-defined]


def _transpose_nvfuser(fd, a, dims):
    return fd.ops.permute(a, dims)  # type: ignore[attr-defined]


def _squeeze_nvfuser(fd, a, a_shape, dimensions):
    # Squeeze one dimension at a time, highest index first, so that the
    # remaining indices stay valid as the shape shrinks.
    for idx in sorted(dimensions, reverse=True):
        a = fd.ops.squeeze(a, a_shape, idx)
        a_shape = a_shape[:idx] + a_shape[idx + 1 :]
    return a


def _view_of_nvfuser(fd, a):
    return fd.ops.set(a)


def _view_nvfuser(
    fd,
    a,
    a_shape,
    new_shape,
):
    return fd.ops.view(a, a_shape, new_shape)


def _sum_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    from nvfuser._C import DataType  # type: ignore[import]

    output_dtype = DataType.Null
    return fd.ops.sum(a, dims, keep_dims, output_dtype)


def _var_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    *,
    correction: int,
):
    keep_dims = False
    return fd.ops.var(a, dims, correction, keep_dims)


def _var_mean_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: int,
):
    # The unbiased arg shouldn't be set when this function is called
    assert unbiased is None
    # Ignore the keepdim arg: currently it's automatically converted into
    # nvFuser's symbolic scalar, and keepdim is instead handled by the
    # reference implementation.
    keepdim = False
    return fd.ops.var_mean(a, dims, correction, keepdim)


def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
    return fd.ops.rand_like(a)


def _amax_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.max(a, dims, keep_dims)


def _amin_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.min(a, dims, keep_dims)


def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
    return fd.ops.set(input)


def _full_nvfuser(
    fd: Any,
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
):
    assert device != torch.device("cpu")
    assert layout is None or layout is torch.strided
    assert pin_memory is False
    assert requires_grad is False
    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.full(shape, fill_value, nvfuser_dtype)


_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["clone"] = _clone_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["view"] = _view_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
_nvfuser_impls["full"] = _full_nvfuser


def register_full():
    name = "full"

    nvprim.define(
        "full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
        + "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
    )

    def _meta_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        strides = make_contiguous_strides_for(size)
        return torch._prims.TensorMeta(
            None,
            shape=size,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        return torch.full(
            size,
            fill_value,
            out=out,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=requires_grad,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_impl)

    prim_packet = getattr(torch._ops.ops.nvprims, name)
    prim = prim_packet.default

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Create a tensor of the given size filled with the given value"
        p.impl_nvfuser = _nvfuser_impls["full"]
        p.is_recomputable = _nvfuser_is_recomputable["full"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
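
# Example (a minimal sketch, not part of the registration): once register_full()
# has run, the primitive is reachable through the nvprims namespace, e.g.
#
#   t = torch.ops.nvprims.full((2, 3), 1.0, dtype=torch.float32, device="cuda")
#
# A CUDA device is assumed here because the nvFuser lowering (_full_nvfuser
# above) asserts that the device is not "cpu".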


# functorch.compile.min_cut_rematerialization_partition accepts a list of
# operators that can be recomputed in the backward pass. If an operator is not
# in this list, it will not be recomputed.
_nvfuser_is_recomputable: Dict[str, bool] = {
    # Reductions are not allowed to be recomputed
    "amax": False,
    "amin": False,
    "sum": False,
    "var": False,
    "var_mean": False,
    # Normalizations are not allowed to be recomputed
    "native_batch_norm": False,
    # Random ops are not allowed to be recomputed
    "rand_like": False,
    # Everything else is allowed to be recomputed
    "abs": True,
    "acos": True,
    "add": True,
    "asin": True,
    "atan": True,
    "atan2": True,
    "atanh": True,
    "bitwise_and": True,
    "bitwise_not": True,
    "bitwise_or": True,
    "bitwise_xor": True,
    "broadcast_in_dim": True,
    "ceil": True,
    "clone": True,
    "convert_element_type": True,
    "cos": True,
    "cosh": True,
    "div": True,
    "eq": True,
    "erf": True,
    "erfc": True,
    "exp": True,
    "expm1": True,
    "floor": True,
    "fmod": True,
    "full": True,
    "ge": True,
    "gt": True,
    "imag": True,
    "isfinite": True,
    "le": True,
    "lgamma": True,
    "log": True,
    "log10": True,
    "log1p": True,
    "log2": True,
    "lt": True,
    "mul": True,
    "ne": True,
    "neg": True,
    "pow": True,
    "real": True,
    "reciprocal": True,
    "remainder": True,
    "round": True,
    "rsqrt": True,
    "sign": True,
    "sin": True,
    "sinh": True,
    "sqrt": True,
    "squeeze": True,
    "sub": True,
    "tan": True,
    "tanh": True,
    "transpose": True,
    "trunc": True,
    "view": True,
    "view_of": True,
    "where": True,
}
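
# For illustration: the register_* functions in this module copy these flags
# onto the registered ops, so after registration one would expect
# (hypothetically) e.g.:
#
#   torch.ops.nvprims.add.default.is_recomputable  # True  (pointwise op)
#   torch.ops.nvprims.sum.default.is_recomputable  # False (reduction)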


def register_native_batch_norm():
    """Registers the native_batch_norm function in the torch.ops.nvprims module."""
    name = "native_batch_norm"

    nvprim.define(
        f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
        + "bool training, float momentum, float eps)"
        + " -> (Tensor, Tensor, Tensor)"
    )

    def _prim_impl(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    ):
        return torch.native_batch_norm(
            input, weight, bias, running_mean, running_var, training, momentum, eps
        )

    nvprim_impl.impl(name, _prim_impl)

    prim_packet = torch._ops.ops.nvprims.native_batch_norm
    prim = prim_packet.default

    def _native_batch_norm_ref(
        input: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        training: bool,
        momentum: float,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if torch._prims_common.is_complex_dtype(input.dtype):
            raise NotImplementedError("Complex tensors are not supported")
        # note: BN only promotes input to the dtype of weight/bias, but keeps the same output dtype
        result_dtype = input.dtype
        computation_dtype, _ = elementwise_dtypes(
            input,
            weight,
            bias,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
        )
        input_ = _maybe_convert_to_dtype(input, computation_dtype)
        output, mean, rstd = prim(
            input_, weight, bias, running_mean, running_var, training, momentum, eps
        )
        output_ = _maybe_convert_to_dtype(output, result_dtype)  # type: ignore[arg-type]
        return (output_, mean, rstd)  # type: ignore[return-value]

    def _native_batch_norm_autograd(
        input: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        training: bool,
        momentum: float,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # This wrapper is needed to convert prims calls inside
        # _native_batch_norm_ref to nvprims calls
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_native_batch_norm_ref)(
                input, weight, bias, running_mean, running_var, training, momentum, eps
            )

    nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes batch normalization."
        p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
        p.is_recomputable = _nvfuser_is_recomputable["native_batch_norm"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_rand_like():
    name = "rand_like"

    nvprim.define(
        "rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
        + "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
    )

    def _meta_rand_like(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        strides = make_contiguous_strides_for(self.shape)
        return torch._prims.TensorMeta(
            self,
            shape=self.shape,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        self,
        *,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=None,
        memory_format=None,
    ):
        return torch.rand_like(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_rand_like)

    prim_packet = getattr(torch._ops.ops.nvprims, name)
    prim = prim_packet.default

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Computes rand_like"
        p.impl_nvfuser = _nvfuser_impls["rand_like"]
        p.is_recomputable = _nvfuser_is_recomputable["rand_like"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_var_mean():
    """Registers the var_mean function in the torch.ops.nvprims module."""
    name = "var_mean.main"

    # This overload must be default for correct dispatching of var_mean(Tensor, bool)
    nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")

    # This signature tries to combine several overloads of the torch.var_mean function into one overload.
    nvprim.define(
        f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
        + " -> (Tensor, Tensor)"
    )

    # This function is used for device="meta" Tensors.
    def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        if torch._prims_common.is_complex_dtype(inp.dtype):
            output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
        else:
            output_dtype = inp.dtype
        var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
        mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
        if keepdim:
            output_shape = [
                inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
            ]
            broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
            var = torch._ops.ops.nvprims.broadcast_in_dim(
                var, output_shape, broadcast_dims
            )
            mean = torch._ops.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
        return (var, mean)

    # This function is used under the _AutoDispatchBelowAutograd context
    def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_var_mean)

    prim_packet = torch._ops.ops.nvprims.var_mean
    prim = prim_packet.main

    def _unbiased_overload_impl(inp, unbiased):
        return prim(inp, dim=None, unbiased=unbiased)

    nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)

    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a",),
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    )
    def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
        correction = torch._prims_common.set_correction(unbiased, correction)
        # reduces over all dimensions if dim=() is passed
        if dim == () or dim == []:
            dim = None
        dim = torch._prims_common.reduction_dims(a.shape, dim)

        # For complex tensors eager computes the variance as the sum of variances of
        # the real and imaginary parts
        # TODO: Creating a complex tensor from real and imaginary parts is not supported
        if torch._prims_common.is_complex_dtype(a.dtype):
            raise NotImplementedError("Complex tensors are not supported")

        var_mean = prim(a, dim, correction=correction)

        if keepdim:
            output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
            broadcast_dims = [i for i in range(a.ndim) if i not in dim]
            var, mean = var_mean
            var = torch._ops.ops.nvprims.broadcast_in_dim(
                var, output_shape, broadcast_dims
            )
            mean = torch._ops.ops.nvprims.broadcast_in_dim(
                mean, output_shape, broadcast_dims
            )
            var_mean = (var, mean)
        return var_mean

    def _var_mean_autograd(
        a, dim=None, unbiased=None, keepdim=False, *, correction=None
    ):
        # This wrapper is needed to convert prims calls inside
        # elementwise_type_promotion_wrapper to nvprims calls
        from torch._prims.context import NvfuserPrimsMode

        with NvfuserPrimsMode():
            return backwards_not_supported(_var_mean_ref)(
                a, dim, unbiased, keepdim, correction=correction
            )

    nvprim_autograd_impl.impl(name, _var_mean_autograd)

    for p in (prim_packet, prim):
        p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
        p.impl_nvfuser = _nvfuser_impls["var_mean"]
        p.is_recomputable = _nvfuser_is_recomputable["var_mean"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
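
# Example (a minimal sketch): after register_var_mean() both schemas defined
# above are callable, where x is assumed to be a floating-point CUDA tensor:
#
#   torch.ops.nvprims.var_mean(x, True)                    # var_mean(Tensor, bool) overload
#   torch.ops.nvprims.var_mean.main(x, [0], correction=1)  # combined "main" overload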


def _nvprims_view_impl_aten(a, original_shape, new_shape):
    return a.reshape(new_shape)


def register_view():
    """Registers the view function in the torch.ops.nvprims module."""
    # View is implemented as a decomposition into prims.split_dim,
    # prims.collapse_dim, and prims.reshape, but we would like to intercept
    # non-decomposed view for now
    name = "view"

    nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
    nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")

    # This function is used under the _AutoDispatchBelowAutograd context
    def _prim_impl(a, original_shape, new_shape):
        return a.reshape(new_shape)

    nvprim_impl.impl(name, _prim_impl)

    prim_packet = torch._ops.ops.nvprims.view
    prim = prim_packet.default

    def _view_no_original_shape_overload_impl(a, shape):
        if list(a.shape) == list(shape):
            return torch.ops.nvprims.view_of(a)
        return torch.ops.nvprims.view.default(a, a.shape, shape)

    nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
        p.impl_nvfuser = _nvfuser_impls["view"]
        p.is_recomputable = _nvfuser_is_recomputable["view"]
        p.return_type = torch._prims_common.RETURN_TYPE.VIEW  # type: ignore[attr-defined]
        p.impl_aten = _nvprims_view_impl_aten
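
# Example (a minimal sketch): the view.shape overload defined above routes to
# view_of when the shape is unchanged and to view.default otherwise, so e.g.
# for a tensor x of shape (2, 3):
#
#   torch.ops.nvprims.view(x, [6])     # -> nvprims.view.default(x, x.shape, [6])
#   torch.ops.nvprims.view(x, [2, 3])  # -> nvprims.view_of(x)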


def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    register_view()
    register_native_batch_norm()
    register_rand_like()
    register_full()

    for name in nvprim_names:
        main_prim = getattr(torch._ops.ops.prims, name)

        nvprim.define(main_prim.schema)
        nvprim_impl.impl(name, main_prim.prim_impl)
        nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)

        prim_packet = getattr(torch._ops.ops.nvprims, name)
        prim = prim_packet.default

        nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

        for p in (prim_packet, prim):
            p.__doc__ = main_prim.__doc__
            p.impl_nvfuser = _nvfuser_impls[name]
            p.is_recomputable = _nvfuser_is_recomputable.get(name, False)
            p.return_type = main_prim.return_type  # type: ignore[attr-defined]
            p.impl_aten = main_prim.impl_aten
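

# Example usage (a minimal sketch): after register_nvprims() has been called,
# the registered ops behave like their prims counterparts in eager mode, e.g.
#
#   register_nvprims()
#   x = torch.randn(4, device="cuda")
#   y = torch.ops.nvprims.sin.default(x)  # numerically matches torch.sin(x)
#
# The CUDA device here is only illustrative; the CompositeExplicitAutograd
# implementations fall back to the corresponding torch/prims functions.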