autograd.py

import copy
import re
from dataclasses import dataclass
from typing import Dict, List, Match, Optional, Sequence, Set, Tuple

from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    FunctionSchema,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX


# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # The NamedCType holds the updated name and cpp type of the attribute;
    # for the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`.
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str
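
# Illustrative example (assumed, not taken from the codegen itself): saving the
# scalar type of `other` instead of the whole tensor could look like
#
#     SavedAttribute(
#         nctype=NamedCType(name="other_scalar_type", type=BaseCType(scalarTypeT)),
#         expr="other.scalar_type()",
#     )
#
# where `scalarTypeT` (from torchgen.api.types) stands in for the saved C++ type.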


# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #  here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: Tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: Tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: Tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: Set[str]


# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (legit C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Names of the output arguments for which this formula calculates forward
    # derivatives.
    var_names: Tuple[str, ...]

    # Types of the output arguments for which this formula calculates forward
    # derivatives.
    var_types: Tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula.
    required_inputs_fw_grad: Optional[Tuple[str, ...]]

    # Inputs for which the primal is required for this formula.
    required_inputs_primal: Optional[Tuple[str, ...]]

    # Flag to specify if this formula requires the original value of self.
    # This is only used by inplace operations.
    required_original_self_value: bool

    # If this formula is specified in derivatives.yaml or if we are re-using the
    # out of place formula for inplace.
    is_reusing_outplace_formula: bool


# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunction having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: Optional[str]

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs.
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs.
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: Set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: Optional[List[bool]]

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions.
    output_differentiability_conditions: Optional[List[str]]

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
    # but with a new operator name.
    # This is used when generating "copy" variants of view ops,
    # which are able to use the exact same derivative formula as the original view op.
    # See Note [Codegen'd {view}_copy Operators]
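    # Illustrative example (assumed op names): for a view op entry like `select.int`,
    # this produces an info named `select_copy.int`; the generated autograd node name,
    # e.g. `SelectBackward0`, would likewise become `SelectBackward0_copy`.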
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> Optional["DifferentiabilityInfo"]:
        if g.view_copy is None:
            return None
        f = g.view_copy

        name_split_by_period = self.name.split(".", maxsplit=2)
        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
            name_split_by_period[1:]
        )
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op
            name=view_copy_name,
            func=f,
            op=view_copy_op_name,
            # But keep all derivative info the same
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )


def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
    return uses_ident(info, "retain_variables")


def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
    return uses_ident(info, "grad")
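

# Illustrative sketch (not part of the codegen): IDENT_REGEX does a whole-identifier
# match, so `uses_single_grad` is True for a formula like "grad * 2" but not for
# "grads[0] * 2":
#
#     assert re.search(IDENT_REGEX.format("grad"), "grad * 2")
#     assert not re.search(IDENT_REGEX.format("grad"), "grads[0] * 2")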


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
    info: Optional[Dict[str, DifferentiabilityInfo]]
    fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration? There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance). Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors. If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)
    """
    # fn is derived as long as any of its per-key differentiability infos
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
    # and ADInplaceOrViewType. We want to generate these functions as long as a
    # derivative is defined for ANY dispatch key.
    if fn.func.is_abstract or (
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
    ):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.
        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer calling the derived
        # type's implementation with unpacked tensors because it is more
        # performant in some cases: any internal calls to other ATen functions
        # won't have the history tracked.
        # If the function has a type dispatched argument (i.e. is a factory),
        # we prefer calling the derived type's implementation both because it is
        # more performant and to ensure factory functions return tensors with _version
        # of 0 (probably not strictly necessary, but nice to have to keep versions simple
        # to understand).
        return "use_derived"
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we'll assume that it is
        # actually implemented out of differentiable functions. (This
        # assumption might not hold, but then you'll see gradcheck fail.)
        return "use_type"


def match_differentiability_info(
    native_functions: List[NativeFunction],
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
) -> List[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to matching autograd function.

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }
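
    # Assumed example of how the grouping above behaves: `signature(strip_default=True)`
    # canonicalizes a schema, so the schemas of e.g. `mul.Tensor`, `mul_.Tensor` and
    # `mul.out` map to the same key, which is what lets the in-place/out variants below
    # fall back to the functional entry.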

    def find_info(
        f: NativeFunction,
    ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e. mul() for mul_() or mul_out()
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature:
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info_dict = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
                for info in info_dict.values()
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
to be used automatically by its functional variant ("{str(f.func)}").
This is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information of unary foreach functions if none is defined in `derivatives.yaml`
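        # Illustrative example (assumed reference formula, for exposition only): for
        # `_foreach_sqrt`, the reference op is `sqrt`; a backward formula such as
        #     grad / (2 * result)
        # would be rewritten for the foreach variant as
        #     grads[i] / (2 * result[i])
        # and any referenced `self`/`result` is saved as a TensorList rather than a Tensor.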
        base_op_name = f.func.name.name
        if (
            base_op_name.base.startswith("_foreach")
            and not base_op_name.inplace
            and len(f.func.arguments.post_self_positional) == 0
        ):
            ref_native_op_name = base_op_name.base.split("_foreach_")[-1]
            for function_schema in functional_info_by_signature:
                if (
                    function_schema.name.name.base == ref_native_op_name
                    and not function_schema.name.name.inplace
                ):
                    all_saved_inputs = []
                    all_saved_outputs = []
                    diff_info_dict = copy.deepcopy(
                        differentiability_infos[function_schema]
                    )
                    diff_info = diff_info_dict["Default"]
                    modified_derivative_formulas = []
                    for derivative in diff_info.derivatives:
                        saved_inputs = []
                        saved_outputs = []
                        modified_formula = (
                            derivative.formula.replace("grad", "grads[i]")
                            .replace("self", "self[i]")
                            .replace("result", "result[i]")
                        )
                        if "self" in modified_formula:
                            saved_inputs.append(
                                SavedAttribute(
                                    nctype=NamedCType(
                                        name="self", type=BaseCType(tensorListT)
                                    ),
                                    expr="self",
                                )
                            )
                            all_saved_inputs.append(saved_inputs[-1])
                        if "result" in modified_formula:
                            saved_outputs.append(
                                SavedAttribute(
                                    nctype=NamedCType(
                                        name="result", type=BaseCType(tensorListT)
                                    ),
                                    expr="result",
                                )
                            )
                            all_saved_outputs.append(saved_outputs[-1])
                        modified_derivative = Derivative(
                            formula=modified_formula,
                            original_formula=derivative.original_formula,
                            var_names=("self",),
                            saved_inputs=tuple(saved_inputs),
                            saved_outputs=tuple(saved_outputs),
                            named_gradients=set(),
                        )
                        modified_derivative_formulas.append(modified_derivative)
                    assert f.func.arguments.self_arg is not None
                    diff_info = DifferentiabilityInfo(
                        name=base_op_name.base,
                        func=f,
                        op=f"Foreach{diff_info.op}",
                        derivatives=modified_derivative_formulas,
                        forward_derivatives=[],
                        all_saved_inputs=tuple(set(all_saved_inputs)),
                        all_saved_outputs=tuple(set(all_saved_outputs)),
                        available_named_gradients=(),
                        used_named_gradients=set(),
                        args_with_derivatives=[
                            Binding(
                                name="self",
                                nctype=NamedCType(
                                    name="self", type=BaseCType(tensorListT)
                                ),
                                argument=f.func.arguments.self_arg.argument,
                                default=None,
                            )
                        ],
                        non_differentiable_arg_names=[],
                        output_differentiability=None,
                        output_differentiability_conditions=None,
                    )
                    diff_info_dict["Default"] = diff_info
                    if f.func not in differentiability_infos:
                        differentiability_infos[f.func] = diff_info_dict
                        functional_info_by_signature[f.func] = diff_info_dict
                    return diff_info_dict, True

        return None, False

    result: List[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if not info_dict:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            # For functions that have a single def for out-of-place and inplace (like abs())
            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified inplace is not used:
                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                #      we make sure that the original value of the input that is being modified inplace (self_p) is
                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
                #      trigger a clone of the original input.
                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                #      populated by a cloned version of the original input (either the clone done by the backward AD
                #      logic if self is also used in a backward formula or a special clone that we add).
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                #     simply called self (as it is modified inplace).
                #  4) Update the required primals data in case it used to contain "result" but should now contain
                #     "self".
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
                #     already exists.
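                #
                # Illustrative example (hypothetical formula, for exposition only):
                # reusing an out-of-place formula "self_t * result" for the inplace
                # variant yields, after the steps above,
                #     "original_self_t * self_p"
                # (self_t -> original_self_t because the out-of-place formula is reused,
                #  and result -> self_p because the inplace result is just self).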
                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed.
                        # Replace "self_p"/"self_t" in the formula by "original_self_p"/"original_self_t".
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # Replace "result" in the formula by "self_p".
                def repl(m: Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2).
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant.
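                    #
                    # Illustrative example (hypothetical op named "mul"): a reused formula
                    #     self_t.mul(other_p)
                    # matches the pattern, so it would be rewritten as
                    #     self_t_raw.defined() ? self_t_raw.mul_(other_p) : self_t.mul(other_p)
                    # instead of going through an extra copy_().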
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: Optional[str] = None
                    between_parens: Optional[str] = None
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # We want to...
                        #   Match: self_t.op1(other_p.op2(arg))
                        #   Avoid: self_t.op1(args) + self_t.op2(args)
                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                        directly_do_inplace = (
                            is_single_method_on_self_t and op_name == info.name
                        )

                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place.
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result


def is_differentiable(
    name: str, type: Type, info: Optional[DifferentiabilityInfo]
) -> bool:
    return type.is_tensor_like() and (
        info is None or name not in info.non_differentiable_arg_names
    )


def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> List[DifferentiableOutput]:
    f = fn.func
    info = fn.info[key] if fn.info else None
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}), "
                f"does not match the number of outputs ({len(outputs)})."
            )
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs

    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
    )
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs